From: @peixu_ren Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -18,27 +18,29 @@ Distributions are the high-level components used to construct the probabilistic | |||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from .transformed_distribution import TransformedDistribution | from .transformed_distribution import TransformedDistribution | ||||
| from .normal import Normal | |||||
| from .bernoulli import Bernoulli | from .bernoulli import Bernoulli | ||||
| from .categorical import Categorical | |||||
| from .cauchy import Cauchy | |||||
| from .exponential import Exponential | from .exponential import Exponential | ||||
| from .uniform import Uniform | |||||
| from .geometric import Geometric | from .geometric import Geometric | ||||
| from .categorical import Categorical | |||||
| from .log_normal import LogNormal | |||||
| from .logistic import Logistic | |||||
| from .gumbel import Gumbel | from .gumbel import Gumbel | ||||
| from .cauchy import Cauchy | |||||
| from .logistic import Logistic | |||||
| from .log_normal import LogNormal | |||||
| from .normal import Normal | |||||
| from .poisson import Poisson | |||||
| from .uniform import Uniform | |||||
| __all__ = ['Distribution', | __all__ = ['Distribution', | ||||
| 'TransformedDistribution', | 'TransformedDistribution', | ||||
| 'Normal', | |||||
| 'Bernoulli', | 'Bernoulli', | ||||
| 'Exponential', | |||||
| 'Uniform', | |||||
| 'Categorical', | 'Categorical', | ||||
| 'Cauchy', | |||||
| 'Exponential', | |||||
| 'Geometric', | 'Geometric', | ||||
| 'LogNormal', | |||||
| 'Logistic', | |||||
| 'Gumbel', | 'Gumbel', | ||||
| 'Cauchy', | |||||
| 'Logistic', | |||||
| 'LogNormal', | |||||
| 'Normal', | |||||
| 'Poisson', | |||||
| 'Uniform', | |||||
| ] | ] | ||||
| @@ -0,0 +1,255 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Poisson Distribution""" | |||||
| import numpy as np | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import composite as C | |||||
| import mindspore.nn as nn | |||||
| from mindspore._checkparam import Validator | |||||
| from mindspore.common import dtype as mstype | |||||
| from .distribution import Distribution | |||||
| from ._utils.utils import check_greater_zero | |||||
| from ._utils.custom_ops import exp_generic, log_generic | |||||
| class Poisson(Distribution): | |||||
| """ | |||||
| Poisson Distribution. | |||||
| Args: | |||||
| rate (float, list, numpy.ndarray, Tensor, Parameter): The rate of the Poisson distribution.. | |||||
| seed (int): The seed used in sampling. The global seed is used if it is None. Default: None. | |||||
| dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32. | |||||
| name (str): The name of the distribution. Default: 'Poisson'. | |||||
| Note: | |||||
| `rate` must be strictly greater than 0. | |||||
| `dist_spec_args` is `rate`. | |||||
| Examples: | |||||
| >>> # To initialize an Poisson distribution of the rate 0.5. | |||||
| >>> import mindspore.nn.probability.distribution as msd | |||||
| >>> p = msd.Poisson(0.5, dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # The following creates two independent Poisson distributions. | |||||
| >>> p = msd.Poisson([0.5, 0.5], dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # An Poisson distribution can be initilized without arguments. | |||||
| >>> # In this case, `rate` must be passed in through `args` during function calls. | |||||
| >>> p = msd.Poisson(dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # To use an Poisson distribution in a network. | |||||
| >>> class net(Cell): | |||||
| ... def __init__(self): | |||||
| ... super(net, self).__init__(): | |||||
| ... self.p1 = msd.Poisson(0.5, dtype=mstype.float32) | |||||
| ... self.p2 = msd.Poisson(dtype=mstype.float32) | |||||
| ... | |||||
| ... # All the following calls in construct are valid. | |||||
| ... def construct(self, value, rate_b, rate_a): | |||||
| ... | |||||
| ... # Private interfaces of probability functions corresponding to public interfaces, including | |||||
| ... # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows. | |||||
| ... # Args: | |||||
| ... # value (Tensor): the value to be evaluated. | |||||
| ... # rate (Tensor): the rate of the distribution. Default: self.rate. | |||||
| ... | |||||
| ... # Examples of `prob`. | |||||
| ... # Similar calls can be made to other probability functions | |||||
| ... # by replacing `prob` by the name of the function. | |||||
| ... ans = self.p1.prob(value) | |||||
| ... # Evaluate with respect to distribution b. | |||||
| ... ans = self.p1.prob(value, rate_b) | |||||
| ... # `rate` must be passed in during function calls. | |||||
| ... ans = self.p2.prob(value, rate_a) | |||||
| ... | |||||
| ... | |||||
| ... # Functions `mean`, `sd`, and 'var' have the same arguments as follows. | |||||
| ... # Args: | |||||
| ... # rate (Tensor): the rate of the distribution. Default: self.rate. | |||||
| ... | |||||
| ... # Examples of `mean`. `sd`, `var`, and `entropy` are similar. | |||||
| ... ans = self.p1.mean() # return 2 | |||||
| ... ans = self.p1.mean(rate_b) # return 1 / rate_b | |||||
| ... # `rate` must be passed in during function calls. | |||||
| ... ans = self.p2.mean(rate_a) | |||||
| ... | |||||
| ... | |||||
| ... # Examples of `sample`. | |||||
| ... # Args: | |||||
| ... # shape (tuple): the shape of the sample. Default: () | |||||
| ... # probs1 (Tensor): the rate of the distribution. Default: self.rate. | |||||
| ... ans = self.p1.sample() | |||||
| ... ans = self.p1.sample((2,3)) | |||||
| ... ans = self.p1.sample((2,3), rate_b) | |||||
| ... ans = self.p2.sample((2,3), rate_a) | |||||
| """ | |||||
| def __init__(self, | |||||
| rate=None, | |||||
| seed=None, | |||||
| dtype=mstype.float32, | |||||
| name="Poisson"): | |||||
| """ | |||||
| Constructor of Poisson. | |||||
| """ | |||||
| param = dict(locals()) | |||||
| param['param_dict'] = {'rate': rate} | |||||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | |||||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||||
| super(Poisson, self).__init__(seed, dtype, name, param) | |||||
| self._rate = self._add_parameter(rate, 'rate') | |||||
| if self.rate is not None: | |||||
| check_greater_zero(self.rate, 'rate') | |||||
| # ops needed for the class | |||||
| self.exp = exp_generic | |||||
| self.log = log_generic | |||||
| self.squeeze = P.Squeeze(0) | |||||
| self.cast = P.Cast() | |||||
| self.floor = P.Floor() | |||||
| self.dtypeop = P.DType() | |||||
| self.shape = P.Shape() | |||||
| self.fill = P.Fill() | |||||
| self.less = P.Less() | |||||
| self.equal = P.Equal() | |||||
| self.select = P.Select() | |||||
| self.lgamma = nn.LGamma() | |||||
| self.igamma = nn.IGamma() | |||||
| self.poisson = C.poisson | |||||
| def extend_repr(self): | |||||
| if self.is_scalar_batch: | |||||
| s = f'rate = {self.rate}' | |||||
| else: | |||||
| s = f'batch_shape = {self._broadcast_shape}' | |||||
| return s | |||||
| @property | |||||
| def rate(self): | |||||
| """ | |||||
| Return `rate` of the distribution. | |||||
| """ | |||||
| return self._rate | |||||
| def _get_dist_type(self): | |||||
| return "Poisson" | |||||
| def _get_dist_args(self, rate=None): | |||||
| if rate is not None: | |||||
| self.checktensor(rate, 'rate') | |||||
| else: | |||||
| rate = self.rate | |||||
| return (rate,) | |||||
| def _mean(self, rate=None): | |||||
| r""" | |||||
| .. math:: | |||||
| MEAN(POISSON) = \lambda. | |||||
| """ | |||||
| rate = self._check_param_type(rate) | |||||
| return rate | |||||
| def _mode(self, rate=None): | |||||
| r""" | |||||
| .. math:: | |||||
| MODE(POISSON) = \lfloor{\lambda}. | |||||
| """ | |||||
| rate = self._check_param_type(rate) | |||||
| return self.floor(rate) | |||||
| def _var(self, rate=None): | |||||
| r""" | |||||
| .. math:: | |||||
| VAR(POISSON) = \lambda. | |||||
| """ | |||||
| rate = self._check_param_type(rate) | |||||
| return rate | |||||
| def _log_prob(self, value, rate=None): | |||||
| r""" | |||||
| Log probability density function of Poisson distributions. | |||||
| Args: | |||||
| Args: | |||||
| value (Tensor): The value to be evaluated. | |||||
| rate (Tensor): The rate of the distribution. Default: self.rate. | |||||
| Note: | |||||
| `value` must be greater or equal to zero. | |||||
| .. math:: | |||||
| log_pdf(x) = x * \log(\lambda) - \lambda - \log(\Gamma(x)) if x >= 0 else -inf | |||||
| """ | |||||
| value = self._check_value(value, "value") | |||||
| value = self.cast(value, self.dtype) | |||||
| rate = self._check_param_type(rate) | |||||
| log_rate = self.log(rate) | |||||
| zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0) | |||||
| inf = self.fill(self.dtypeop(value), self.shape(value), np.inf) | |||||
| safe_x = self.select(self.less(value, zeros), zeros, value) | |||||
| y = log_rate * safe_x - self.lgamma(safe_x + 1.) | |||||
| comp = self.equal(value, safe_x) | |||||
| log_unnormalized_prob = self.select(comp, y, -inf) | |||||
| log_normalization = self.exp(log_rate) | |||||
| return log_unnormalized_prob - log_normalization | |||||
| def _cdf(self, value, rate=None): | |||||
| r""" | |||||
| Cumulative distribution function (cdf) of Poisson distributions. | |||||
| Args: | |||||
| value (Tensor): The value to be evaluated. | |||||
| rate (Tensor): The rate of the distribution. Default: self.rate. | |||||
| Note: | |||||
| `value` must be greater or equal to zero. | |||||
| .. math:: | |||||
| cdf(x) = \Gamma(x + 1) if x >= 0 else 0 | |||||
| """ | |||||
| value = self._check_value(value, 'value') | |||||
| value = self.cast(value, self.dtype) | |||||
| rate = self._check_param_type(rate) | |||||
| zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0) | |||||
| comp = self.less(value, zeros) | |||||
| safe_x = self.select(comp, zeros, value) | |||||
| cdf = 1. - self.igamma(1. + safe_x, rate) | |||||
| return self.select(comp, zeros, cdf) | |||||
| def _sample(self, shape=(), rate=None): | |||||
| """ | |||||
| Sampling. | |||||
| Args: | |||||
| shape (tuple): The shape of the sample. Default: (). | |||||
| rate (Tensor): The rate of the distribution. Default: self.rate. | |||||
| Returns: | |||||
| Tensor, shape is shape + batch_shape. | |||||
| """ | |||||
| shape = self.checktuple(shape, 'shape') | |||||
| rate = self._check_param_type(rate) | |||||
| origin_shape = shape + self.shape(rate) | |||||
| if origin_shape == (): | |||||
| sample_shape = (1,) | |||||
| else: | |||||
| sample_shape = origin_shape | |||||
| sample_poisson = self.poisson(sample_shape, rate, self.seed) | |||||
| value = self.cast(sample_poisson, self.dtype) | |||||
| if origin_shape == (): | |||||
| value = self.squeeze(value) | |||||
| return value | |||||
| @@ -0,0 +1,210 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """test cases for Poisson distribution""" | |||||
| import numpy as np | |||||
| from scipy import stats | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| import mindspore.nn.probability.distribution as msd | |||||
| from mindspore import Tensor | |||||
| from mindspore import dtype | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Prob(nn.Cell): | |||||
| """ | |||||
| Test class: probability of Poisson distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Prob, self).__init__() | |||||
| self.p = msd.Poisson([0.5], dtype=dtype.float32) | |||||
| def construct(self, x_): | |||||
| return self.p.prob(x_) | |||||
| def test_pdf(): | |||||
| """ | |||||
| Test pdf. | |||||
| """ | |||||
| poisson_benchmark = stats.poisson(mu=0.5) | |||||
| expect_pdf = poisson_benchmark.pmf([-1.0, 0.0, 1.0]).astype(np.float32) | |||||
| pdf = Prob() | |||||
| x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = pdf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() | |||||
| class LogProb(nn.Cell): | |||||
| """ | |||||
| Test class: log probability of Poisson distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogProb, self).__init__() | |||||
| self.p = msd.Poisson(0.5, dtype=dtype.float32) | |||||
| def construct(self, x_): | |||||
| return self.p.log_prob(x_) | |||||
| def test_log_likelihood(): | |||||
| """ | |||||
| Test log_pdf. | |||||
| """ | |||||
| poisson_benchmark = stats.poisson(mu=0.5) | |||||
| expect_logpdf = poisson_benchmark.logpmf([1.0, 2.0]).astype(np.float32) | |||||
| logprob = LogProb() | |||||
| x_ = Tensor(np.array([1.0, 2.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = logprob(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() | |||||
| class Basics(nn.Cell): | |||||
| """ | |||||
| Test class: mean/sd/mode of Poisson distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Basics, self).__init__() | |||||
| self.p = msd.Poisson([1.44], dtype=dtype.float32) | |||||
| def construct(self): | |||||
| return self.p.mean(), self.p.sd(), self.p.mode() | |||||
| def test_basics(): | |||||
| """ | |||||
| Test mean/standard/mode deviation. | |||||
| """ | |||||
| basics = Basics() | |||||
| mean, sd, mode = basics() | |||||
| expect_mean = 1.44 | |||||
| expect_sd = 1.2 | |||||
| expect_mode = 1 | |||||
| tol = 1e-6 | |||||
| assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() | |||||
| assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() | |||||
| assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() | |||||
| class Sampling(nn.Cell): | |||||
| """ | |||||
| Test class: sample of Poisson distribution. | |||||
| """ | |||||
| def __init__(self, shape, seed=0): | |||||
| super(Sampling, self).__init__() | |||||
| self.p = msd.Poisson([[1.0], [0.5]], seed=seed, dtype=dtype.float32) | |||||
| self.shape = shape | |||||
| def construct(self, rate=None): | |||||
| return self.p.sample(self.shape, rate) | |||||
| def test_sample(): | |||||
| """ | |||||
| Test sample. | |||||
| """ | |||||
| shape = (2, 3) | |||||
| seed = 10 | |||||
| rate = Tensor([1.0, 2.0, 3.0], dtype=dtype.float32) | |||||
| sample = Sampling(shape, seed=seed) | |||||
| output = sample(rate) | |||||
| assert output.shape == (2, 3, 3) | |||||
| class CDF(nn.Cell): | |||||
| """ | |||||
| Test class: cdf of Poisson distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(CDF, self).__init__() | |||||
| self.p = msd.Poisson([0.5], dtype=dtype.float32) | |||||
| def construct(self, x_): | |||||
| return self.p.cdf(x_) | |||||
| def test_cdf(): | |||||
| """ | |||||
| Test cdf. | |||||
| """ | |||||
| poisson_benchmark = stats.poisson(mu=0.5) | |||||
| expect_cdf = poisson_benchmark.cdf([-1.0, 0.0, 1.0]).astype(np.float32) | |||||
| cdf = CDF() | |||||
| x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = cdf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() | |||||
| class LogCDF(nn.Cell): | |||||
| """ | |||||
| Test class: log_cdf of Poisson distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogCDF, self).__init__() | |||||
| self.p = msd.Poisson([0.5], dtype=dtype.float32) | |||||
| def construct(self, x_): | |||||
| return self.p.log_cdf(x_) | |||||
| def test_log_cdf(): | |||||
| """ | |||||
| Test log_cdf. | |||||
| """ | |||||
| poisson_benchmark = stats.poisson(mu=0.5) | |||||
| expect_logcdf = poisson_benchmark.logcdf([0.5, 1.0, 2.5]).astype(np.float32) | |||||
| logcdf = LogCDF() | |||||
| x_ = Tensor(np.array([0.5, 1.0, 2.5]).astype(np.float32), dtype=dtype.float32) | |||||
| output = logcdf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() | |||||
| class SF(nn.Cell): | |||||
| """ | |||||
| Test class: survival function of Poisson distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(SF, self).__init__() | |||||
| self.p = msd.Poisson(0.5, dtype=dtype.float32) | |||||
| def construct(self, x_): | |||||
| return self.p.survival_function(x_) | |||||
| def test_survival(): | |||||
| """ | |||||
| Test survival function. | |||||
| """ | |||||
| poisson_benchmark = stats.poisson(mu=0.5) | |||||
| expect_survival = poisson_benchmark.sf([-1.0, 0.0, 1.0]).astype(np.float32) | |||||
| survival = SF() | |||||
| x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = survival(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_survival) < tol).all() | |||||
| class LogSF(nn.Cell): | |||||
| """ | |||||
| Test class: log survival function of Poisson distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogSF, self).__init__() | |||||
| self.p = msd.Poisson(0.5, dtype=dtype.float32) | |||||
| def construct(self, x_): | |||||
| return self.p.log_survival(x_) | |||||
| def test_log_survival(): | |||||
| """ | |||||
| Test log survival function. | |||||
| """ | |||||
| poisson_benchmark = stats.poisson(mu=0.5) | |||||
| expect_logsurvival = poisson_benchmark.logsf([-1.0, 0.0, 1.0]).astype(np.float32) | |||||
| logsurvival = LogSF() | |||||
| x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = logsurvival(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() | |||||
| @@ -0,0 +1,154 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Test nn.probability.distribution.Poisson. | |||||
| """ | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| import mindspore.nn.probability.distribution as msd | |||||
| from mindspore import dtype | |||||
| from mindspore import Tensor | |||||
| def test_arguments(): | |||||
| """ | |||||
| Args passing during initialization. | |||||
| """ | |||||
| p = msd.Poisson() | |||||
| assert isinstance(p, msd.Distribution) | |||||
| p = msd.Poisson([0.1, 0.3, 0.5, 1.0], dtype=dtype.float32) | |||||
| assert isinstance(p, msd.Distribution) | |||||
| def test_type(): | |||||
| with pytest.raises(TypeError): | |||||
| msd.Poisson([0.1], dtype=dtype.bool_) | |||||
| def test_name(): | |||||
| with pytest.raises(TypeError): | |||||
| msd.Poisson([0.1], name=1.0) | |||||
| def test_seed(): | |||||
| with pytest.raises(TypeError): | |||||
| msd.Poisson([0.1], seed='seed') | |||||
| def test_rate(): | |||||
| """ | |||||
| Invalid rate. | |||||
| """ | |||||
| with pytest.raises(ValueError): | |||||
| msd.Poisson([-0.1], dtype=dtype.float32) | |||||
| with pytest.raises(ValueError): | |||||
| msd.Poisson([0.0], dtype=dtype.float32) | |||||
| class PoissonProb(nn.Cell): | |||||
| """ | |||||
| Poisson distribution: initialize with rate. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(PoissonProb, self).__init__() | |||||
| self.p = msd.Poisson([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32) | |||||
| def construct(self, value): | |||||
| prob = self.p.prob(value) | |||||
| log_prob = self.p.log_prob(value) | |||||
| cdf = self.p.cdf(value) | |||||
| log_cdf = self.p.log_cdf(value) | |||||
| sf = self.p.survival_function(value) | |||||
| log_sf = self.p.log_survival(value) | |||||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||||
| def test_poisson_prob(): | |||||
| """ | |||||
| Test probability functions: passing value through construct. | |||||
| """ | |||||
| net = PoissonProb() | |||||
| value = Tensor([0.2, 0.3, 5.0, 2, 3.9], dtype=dtype.float32) | |||||
| ans = net(value) | |||||
| assert isinstance(ans, Tensor) | |||||
| class PoissonProb1(nn.Cell): | |||||
| """ | |||||
| Poisson distribution: initialize without rate. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(PoissonProb1, self).__init__() | |||||
| self.p = msd.Poisson(dtype=dtype.float32) | |||||
| def construct(self, value, rate): | |||||
| prob = self.p.prob(value, rate) | |||||
| log_prob = self.p.log_prob(value, rate) | |||||
| cdf = self.p.cdf(value, rate) | |||||
| log_cdf = self.p.log_cdf(value, rate) | |||||
| sf = self.p.survival_function(value, rate) | |||||
| log_sf = self.p.log_survival(value, rate) | |||||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||||
| def test_poisson_prob1(): | |||||
| """ | |||||
| Test probability functions: passing value/rate through construct. | |||||
| """ | |||||
| net = PoissonProb1() | |||||
| value = Tensor([0.2, 0.9, 1, 2, 3], dtype=dtype.float32) | |||||
| rate = Tensor([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32) | |||||
| ans = net(value, rate) | |||||
| assert isinstance(ans, Tensor) | |||||
| class PoissonBasics(nn.Cell): | |||||
| """ | |||||
| Test class: basic mean/sd/var/mode function. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(PoissonBasics, self).__init__() | |||||
| self.p = msd.Poisson([2.3, 2.5], dtype=dtype.float32) | |||||
| def construct(self): | |||||
| mean = self.p.mean() | |||||
| sd = self.p.sd() | |||||
| var = self.p.var() | |||||
| return mean + sd + var | |||||
| def test_bascis(): | |||||
| """ | |||||
| Test mean/sd/var/mode functionality of Poisson distribution. | |||||
| """ | |||||
| net = PoissonBasics() | |||||
| ans = net() | |||||
| assert isinstance(ans, Tensor) | |||||
| class PoissonConstruct(nn.Cell): | |||||
| """ | |||||
| Poisson distribution: going through construct. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(PoissonConstruct, self).__init__() | |||||
| self.p = msd.Poisson([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32) | |||||
| self.p1 = msd.Poisson(dtype=dtype.float32) | |||||
| def construct(self, value, rate): | |||||
| prob = self.p('prob', value) | |||||
| prob1 = self.p('prob', value, rate) | |||||
| prob2 = self.p1('prob', value, rate) | |||||
| return prob + prob1 + prob2 | |||||
| def test_poisson_construct(): | |||||
| """ | |||||
| Test probability function going through construct. | |||||
| """ | |||||
| net = PoissonConstruct() | |||||
| value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) | |||||
| probs = Tensor([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32) | |||||
| ans = net(value, probs) | |||||
| assert isinstance(ans, Tensor) | |||||