From 01f5da0a14ce9bfde2c3f03597efbc38157509ae Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Sun, 15 Nov 2020 22:14:52 -0500 Subject: [PATCH] Add Poisson distribution --- .../nn/probability/distribution/__init__.py | 26 +- .../nn/probability/distribution/poisson.py | 255 ++++++++++++++++++ .../probability/distribution/test_poisson.py | 210 +++++++++++++++ .../probability/distribution/test_poisson.py | 154 +++++++++++ 4 files changed, 633 insertions(+), 12 deletions(-) create mode 100644 mindspore/nn/probability/distribution/poisson.py create mode 100644 tests/st/probability/distribution/test_poisson.py create mode 100644 tests/ut/python/nn/probability/distribution/test_poisson.py diff --git a/mindspore/nn/probability/distribution/__init__.py b/mindspore/nn/probability/distribution/__init__.py index a2e49e28e1..0d35d2ed8a 100644 --- a/mindspore/nn/probability/distribution/__init__.py +++ b/mindspore/nn/probability/distribution/__init__.py @@ -18,27 +18,29 @@ Distributions are the high-level components used to construct the probabilistic from .distribution import Distribution from .transformed_distribution import TransformedDistribution -from .normal import Normal from .bernoulli import Bernoulli +from .categorical import Categorical +from .cauchy import Cauchy from .exponential import Exponential -from .uniform import Uniform from .geometric import Geometric -from .categorical import Categorical -from .log_normal import LogNormal -from .logistic import Logistic from .gumbel import Gumbel -from .cauchy import Cauchy +from .logistic import Logistic +from .log_normal import LogNormal +from .normal import Normal +from .poisson import Poisson +from .uniform import Uniform __all__ = ['Distribution', 'TransformedDistribution', - 'Normal', 'Bernoulli', - 'Exponential', - 'Uniform', 'Categorical', + 'Cauchy', + 'Exponential', 'Geometric', - 'LogNormal', - 'Logistic', 'Gumbel', - 'Cauchy', + 'Logistic', + 'LogNormal', + 'Normal', + 'Poisson', + 'Uniform', ] diff --git 
a/mindspore/nn/probability/distribution/poisson.py b/mindspore/nn/probability/distribution/poisson.py
new file mode 100644
index 0000000000..f9f77cfd17
--- /dev/null
+++ b/mindspore/nn/probability/distribution/poisson.py
@@ -0,0 +1,255 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Poisson Distribution"""
+import numpy as np
+from mindspore.ops import operations as P
+from mindspore.ops import composite as C
+import mindspore.nn as nn
+from mindspore._checkparam import Validator
+from mindspore.common import dtype as mstype
+from .distribution import Distribution
+from ._utils.utils import check_greater_zero
+from ._utils.custom_ops import exp_generic, log_generic
+
+
+class Poisson(Distribution):
+    """
+    Poisson Distribution.
+
+    Args:
+        rate (float, list, numpy.ndarray, Tensor, Parameter): The rate of the Poisson distribution.
+        seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
+        dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
+        name (str): The name of the distribution. Default: 'Poisson'.
+
+    Note:
+        `rate` must be strictly greater than 0.
+        `dist_spec_args` is `rate`.
+
+    Examples:
+        >>> import mindspore.nn as nn
+        >>> import mindspore.nn.probability.distribution as msd
+        >>> from mindspore.common import dtype as mstype
+        >>> # To initialize a Poisson distribution of the rate 0.5.
+        >>> p = msd.Poisson(0.5, dtype=mstype.float32)
+        >>>
+        >>> # The following creates two independent Poisson distributions.
+        >>> p = msd.Poisson([0.5, 0.5], dtype=mstype.float32)
+        >>>
+        >>> # A Poisson distribution can be initialized without arguments.
+        >>> # In this case, `rate` must be passed in through `args` during function calls.
+        >>> p = msd.Poisson(dtype=mstype.float32)
+        >>>
+        >>> # To use a Poisson distribution in a network.
+        >>> class net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(net, self).__init__()
+        ...         self.p1 = msd.Poisson(0.5, dtype=mstype.float32)
+        ...         self.p2 = msd.Poisson(dtype=mstype.float32)
+        ...
+        ...     # All the following calls in construct are valid.
+        ...     def construct(self, value, rate_b, rate_a):
+        ...
+        ...         # Private interfaces of probability functions corresponding to public interfaces, including
+        ...         # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows.
+        ...         # Args:
+        ...         #     value (Tensor): the value to be evaluated.
+        ...         #     rate (Tensor): the rate of the distribution. Default: self.rate.
+        ...
+        ...         # Examples of `prob`.
+        ...         # Similar calls can be made to other probability functions
+        ...         # by replacing `prob` by the name of the function.
+        ...         ans = self.p1.prob(value)
+        ...         # Evaluate with respect to distribution b.
+        ...         ans = self.p1.prob(value, rate_b)
+        ...         # `rate` must be passed in during function calls.
+        ...         ans = self.p2.prob(value, rate_a)
+        ...
+        ...         # Functions `mean`, `sd`, and `var` have the same arguments as follows.
+        ...         # Args:
+        ...         #     rate (Tensor): the rate of the distribution. Default: self.rate.
+        ...
+        ...         # Examples of `mean`. `sd`, `var`, and `mode` are similar.
+        ...         ans = self.p1.mean() # return 0.5
+        ...         ans = self.p1.mean(rate_b) # return rate_b
+        ...         # `rate` must be passed in during function calls.
+        ...         ans = self.p2.mean(rate_a)
+        ...
+        ...         # Examples of `sample`.
+        ...         # Args:
+        ...         #     shape (tuple): the shape of the sample. Default: ()
+        ...         #     rate (Tensor): the rate of the distribution. Default: self.rate.
+        ...         ans = self.p1.sample()
+        ...         ans = self.p1.sample((2,3))
+        ...         ans = self.p1.sample((2,3), rate_b)
+        ...         ans = self.p2.sample((2,3), rate_a)
+    """
+
+    def __init__(self,
+                 rate=None,
+                 seed=None,
+                 dtype=mstype.float32,
+                 name="Poisson"):
+        """
+        Constructor of Poisson.
+        """
+        param = dict(locals())
+        param['param_dict'] = {'rate': rate}
+        valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
+        super(Poisson, self).__init__(seed, dtype, name, param)
+
+        self._rate = self._add_parameter(rate, 'rate')
+        if self.rate is not None:
+            check_greater_zero(self.rate, 'rate')
+
+        # ops needed for the class
+        self.exp = exp_generic
+        self.log = log_generic
+        self.squeeze = P.Squeeze(0)
+        self.cast = P.Cast()
+        self.floor = P.Floor()
+        self.dtypeop = P.DType()
+        self.shape = P.Shape()
+        self.fill = P.Fill()
+        self.less = P.Less()
+        self.equal = P.Equal()
+        self.select = P.Select()
+        self.lgamma = nn.LGamma()
+        self.igamma = nn.IGamma()
+        self.poisson = C.poisson
+
+    def extend_repr(self):
+        if self.is_scalar_batch:
+            s = f'rate = {self.rate}'
+        else:
+            s = f'batch_shape = {self._broadcast_shape}'
+        return s
+
+    @property
+    def rate(self):
+        """
+        Return `rate` of the distribution.
+        """
+        return self._rate
+
+    def _get_dist_type(self):
+        return "Poisson"
+
+    def _get_dist_args(self, rate=None):
+        if rate is not None:
+            self.checktensor(rate, 'rate')
+        else:
+            rate = self.rate
+        return (rate,)
+
+    def _mean(self, rate=None):
+        r"""
+        .. math::
+            MEAN(POISSON) = \lambda.
+        """
+        rate = self._check_param_type(rate)
+        return rate
+
+    def _mode(self, rate=None):
+        r"""
+        .. math::
+            MODE(POISSON) = \lfloor{\lambda}\rfloor.
+        """
+        rate = self._check_param_type(rate)
+        return self.floor(rate)
+
+    def _var(self, rate=None):
+        r"""
+        .. math::
+            VAR(POISSON) = \lambda.
+        """
+        rate = self._check_param_type(rate)
+        return rate
+
+    def _log_prob(self, value, rate=None):
+        r"""
+        Log probability mass function of Poisson distributions.
+
+        Args:
+            value (Tensor): The value to be evaluated.
+            rate (Tensor): The rate of the distribution. Default: self.rate.
+
+        Note:
+            `value` must be greater or equal to zero; negative `value` gives -inf.
+
+        .. math::
+            log_pdf(x) = x * \log(\lambda) - \lambda - \log(\Gamma(x + 1)) if x >= 0 else -inf
+        """
+        value = self._check_value(value, "value")
+        value = self.cast(value, self.dtype)
+        rate = self._check_param_type(rate)
+        log_rate = self.log(rate)
+        zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
+        inf = self.fill(self.dtypeop(value), self.shape(value), np.inf)
+        safe_x = self.select(self.less(value, zeros), zeros, value)  # negative entries are replaced by 0
+        y = log_rate * safe_x - self.lgamma(safe_x + 1.)
+        comp = self.equal(value, safe_x)  # False exactly where value was negative
+        log_unnormalized_prob = self.select(comp, y, -inf)
+        log_normalization = self.exp(log_rate)  # exp(log(rate)), i.e. the rate itself
+        return log_unnormalized_prob - log_normalization
+
+    def _cdf(self, value, rate=None):
+        r"""
+        Cumulative distribution function (cdf) of Poisson distributions.
+
+        Args:
+            value (Tensor): The value to be evaluated.
+            rate (Tensor): The rate of the distribution. Default: self.rate.
+
+        Note:
+            Negative `value` gives 0. `P` is the regularized lower incomplete gamma function (`nn.IGamma`).
+
+        .. math::
+            cdf(x) = 1 - P(\lfloor{x}\rfloor + 1, \lambda) if x >= 0 else 0
+        """
+        value = self._check_value(value, 'value')
+        value = self.cast(value, self.dtype)
+        rate = self._check_param_type(rate)
+        zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
+        comp = self.less(value, zeros)
+        safe_x = self.select(comp, zeros, value)
+        # The cdf is a step function that is constant on [k, k + 1): evaluate it at floor(value).
+        cdf = 1. - self.igamma(self.floor(1. + safe_x), rate)
+        return self.select(comp, zeros, cdf)
+
+    def _sample(self, shape=(), rate=None):
+        """
+        Sampling.
+
+        Args:
+            shape (tuple): The shape of the sample. Default: ().
+            rate (Tensor): The rate of the distribution. Default: self.rate.
+
+        Returns:
+            Tensor, shape is shape + batch_shape.
+        """
+        shape = self.checktuple(shape, 'shape')
+        rate = self._check_param_type(rate)
+        origin_shape = shape + self.shape(rate)
+        if origin_shape == ():
+            sample_shape = (1,)
+        else:
+            sample_shape = origin_shape
+        sample_poisson = self.poisson(sample_shape, rate, self.seed)
+        value = self.cast(sample_poisson, self.dtype)
+        if origin_shape == ():
+            value = self.squeeze(value)
+        return value
diff --git a/tests/st/probability/distribution/test_poisson.py b/tests/st/probability/distribution/test_poisson.py
new file mode 100644
index 0000000000..40e7358cc0
--- /dev/null
+++ b/tests/st/probability/distribution/test_poisson.py
@@ -0,0 +1,210 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test cases for Poisson distribution"""
+import numpy as np
+from scipy import stats
+import mindspore.context as context
+import mindspore.nn as nn
+import mindspore.nn.probability.distribution as msd
+from mindspore import Tensor
+from mindspore import dtype
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+class Prob(nn.Cell):
+    """
+    Test class: probability of Poisson distribution.
+    """
+    def __init__(self):
+        super(Prob, self).__init__()
+        self.p = msd.Poisson([0.5], dtype=dtype.float32)
+
+    def construct(self, x_):
+        return self.p.prob(x_)
+
+def test_pdf():
+    """
+    Test prob against the pmf of scipy.stats.poisson.
+    """
+    poisson_benchmark = stats.poisson(mu=0.5)
+    expect_pdf = poisson_benchmark.pmf([-1.0, 0.0, 1.0]).astype(np.float32)
+    pdf = Prob()
+    x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32)
+    output = pdf(x_)
+    tol = 1e-6
+    assert (np.abs(output.asnumpy() - expect_pdf) < tol).all()
+
+class LogProb(nn.Cell):
+    """
+    Test class: log probability of Poisson distribution.
+    """
+    def __init__(self):
+        super(LogProb, self).__init__()
+        self.p = msd.Poisson(0.5, dtype=dtype.float32)
+
+    def construct(self, x_):
+        return self.p.log_prob(x_)
+
+def test_log_likelihood():
+    """
+    Test log_prob against the logpmf of scipy.stats.poisson.
+    """
+    poisson_benchmark = stats.poisson(mu=0.5)
+    expect_logpdf = poisson_benchmark.logpmf([1.0, 2.0]).astype(np.float32)
+    logprob = LogProb()
+    x_ = Tensor(np.array([1.0, 2.0]).astype(np.float32), dtype=dtype.float32)
+    output = logprob(x_)
+    tol = 1e-6
+    assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all()
+
+class Basics(nn.Cell):
+    """
+    Test class: mean/sd/mode of Poisson distribution.
+    """
+    def __init__(self):
+        super(Basics, self).__init__()
+        self.p = msd.Poisson([1.44], dtype=dtype.float32)
+
+    def construct(self):
+        return self.p.mean(), self.p.sd(), self.p.mode()
+
+def test_basics():
+    """
+    Test mean, standard deviation, and mode.
+    """
+    basics = Basics()
+    mean, sd, mode = basics()
+    expect_mean = 1.44
+    expect_sd = 1.2
+    expect_mode = 1
+    tol = 1e-6
+    assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
+    assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
+    assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
+
+class Sampling(nn.Cell):
+    """
+    Test class: sample of Poisson distribution.
+    """
+    def __init__(self, shape, seed=0):
+        super(Sampling, self).__init__()
+        self.p = msd.Poisson([[1.0], [0.5]], seed=seed, dtype=dtype.float32)
+        self.shape = shape
+
+    def construct(self, rate=None):
+        return self.p.sample(self.shape, rate)
+
+def test_sample():
+    """
+    Test the shape of a sample drawn with a rate passed in at call time.
+    """
+    shape = (2, 3)
+    seed = 10
+    rate = Tensor([1.0, 2.0, 3.0], dtype=dtype.float32)
+    sample = Sampling(shape, seed=seed)
+    output = sample(rate)
+    assert output.shape == (2, 3, 3)
+
+class CDF(nn.Cell):
+    """
+    Test class: cdf of Poisson distribution.
+    """
+    def __init__(self):
+        super(CDF, self).__init__()
+        self.p = msd.Poisson([0.5], dtype=dtype.float32)
+
+    def construct(self, x_):
+        return self.p.cdf(x_)
+
+def test_cdf():
+    """
+    Test cdf against the cdf of scipy.stats.poisson.
+    """
+    poisson_benchmark = stats.poisson(mu=0.5)
+    expect_cdf = poisson_benchmark.cdf([-1.0, 0.0, 1.0]).astype(np.float32)
+    cdf = CDF()
+    x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32)
+    output = cdf(x_)
+    tol = 1e-6
+    assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
+
+class LogCDF(nn.Cell):
+    """
+    Test class: log_cdf of Poisson distribution.
+    """
+    def __init__(self):
+        super(LogCDF, self).__init__()
+        self.p = msd.Poisson([0.5], dtype=dtype.float32)
+
+    def construct(self, x_):
+        return self.p.log_cdf(x_)
+
+def test_log_cdf():
+    """
+    Test log_cdf against the logcdf of scipy.stats.poisson, including non-integer values.
+    """
+    poisson_benchmark = stats.poisson(mu=0.5)
+    expect_logcdf = poisson_benchmark.logcdf([0.5, 1.0, 2.5]).astype(np.float32)
+    logcdf = LogCDF()
+    x_ = Tensor(np.array([0.5, 1.0, 2.5]).astype(np.float32), dtype=dtype.float32)
+    output = logcdf(x_)
+    tol = 1e-6
+    assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()
+
+class SF(nn.Cell):
+    """
+    Test class: survival function of Poisson distribution.
+    """
+    def __init__(self):
+        super(SF, self).__init__()
+        self.p = msd.Poisson(0.5, dtype=dtype.float32)
+
+    def construct(self, x_):
+        return self.p.survival_function(x_)
+
+def test_survival():
+    """
+    Test survival function.
+    """
+    poisson_benchmark = stats.poisson(mu=0.5)
+    expect_survival = poisson_benchmark.sf([-1.0, 0.0, 1.0]).astype(np.float32)
+    survival = SF()
+    x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32)
+    output = survival(x_)
+    tol = 1e-6
+    assert (np.abs(output.asnumpy() - expect_survival) < tol).all()
+
+class LogSF(nn.Cell):
+    """
+    Test class: log survival function of Poisson distribution.
+    """
+    def __init__(self):
+        super(LogSF, self).__init__()
+        self.p = msd.Poisson(0.5, dtype=dtype.float32)
+
+    def construct(self, x_):
+        return self.p.log_survival(x_)
+
+def test_log_survival():
+    """
+    Test log survival function against the logsf of scipy.stats.poisson.
+    """
+    poisson_benchmark = stats.poisson(mu=0.5)
+    expect_logsurvival = poisson_benchmark.logsf([-1.0, 0.0, 1.0]).astype(np.float32)
+    logsurvival = LogSF()
+    x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32)
+    output = logsurvival(x_)
+    tol = 1e-6
+    assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()
diff --git a/tests/ut/python/nn/probability/distribution/test_poisson.py b/tests/ut/python/nn/probability/distribution/test_poisson.py
new file mode 100644
index 0000000000..318ef31ef5
--- /dev/null
+++ b/tests/ut/python/nn/probability/distribution/test_poisson.py
@@ -0,0 +1,154 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Test nn.probability.distribution.Poisson.
+"""
+import pytest
+
+import mindspore.nn as nn
+import mindspore.nn.probability.distribution as msd
+from mindspore import dtype
+from mindspore import Tensor
+
+
+def test_arguments():
+    """
+    Args passing during initialization.
+    """
+    p = msd.Poisson()
+    assert isinstance(p, msd.Distribution)
+    p = msd.Poisson([0.1, 0.3, 0.5, 1.0], dtype=dtype.float32)
+    assert isinstance(p, msd.Distribution)
+
+def test_type():
+    with pytest.raises(TypeError):
+        msd.Poisson([0.1], dtype=dtype.bool_)
+
+def test_name():
+    with pytest.raises(TypeError):
+        msd.Poisson([0.1], name=1.0)
+
+def test_seed():
+    with pytest.raises(TypeError):
+        msd.Poisson([0.1], seed='seed')
+
+def test_rate():
+    """
+    Invalid rate: negative and zero rates must be rejected.
+    """
+    with pytest.raises(ValueError):
+        msd.Poisson([-0.1], dtype=dtype.float32)
+    with pytest.raises(ValueError):
+        msd.Poisson([0.0], dtype=dtype.float32)
+
+class PoissonProb(nn.Cell):
+    """
+    Poisson distribution: initialize with rate.
+    """
+    def __init__(self):
+        super(PoissonProb, self).__init__()
+        self.p = msd.Poisson([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32)
+
+    def construct(self, value):
+        prob = self.p.prob(value)
+        log_prob = self.p.log_prob(value)
+        cdf = self.p.cdf(value)
+        log_cdf = self.p.log_cdf(value)
+        sf = self.p.survival_function(value)
+        log_sf = self.p.log_survival(value)
+        return prob + log_prob + cdf + log_cdf + sf + log_sf
+
+def test_poisson_prob():
+    """
+    Test probability functions: passing value through construct.
+    """
+    net = PoissonProb()
+    value = Tensor([0.2, 0.3, 5.0, 2, 3.9], dtype=dtype.float32)
+    ans = net(value)
+    assert isinstance(ans, Tensor)
+
+class PoissonProb1(nn.Cell):
+    """
+    Poisson distribution: initialize without rate.
+    """
+    def __init__(self):
+        super(PoissonProb1, self).__init__()
+        self.p = msd.Poisson(dtype=dtype.float32)
+
+    def construct(self, value, rate):
+        prob = self.p.prob(value, rate)
+        log_prob = self.p.log_prob(value, rate)
+        cdf = self.p.cdf(value, rate)
+        log_cdf = self.p.log_cdf(value, rate)
+        sf = self.p.survival_function(value, rate)
+        log_sf = self.p.log_survival(value, rate)
+        return prob + log_prob + cdf + log_cdf + sf + log_sf
+
+def test_poisson_prob1():
+    """
+    Test probability functions: passing value/rate through construct.
+    """
+    net = PoissonProb1()
+    value = Tensor([0.2, 0.9, 1, 2, 3], dtype=dtype.float32)
+    rate = Tensor([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32)
+    ans = net(value, rate)
+    assert isinstance(ans, Tensor)
+
+class PoissonBasics(nn.Cell):
+    """
+    Test class: basic mean/sd/var function.
+    """
+    def __init__(self):
+        super(PoissonBasics, self).__init__()
+        self.p = msd.Poisson([2.3, 2.5], dtype=dtype.float32)
+
+    def construct(self):
+        mean = self.p.mean()
+        sd = self.p.sd()
+        var = self.p.var()
+        return mean + sd + var
+
+def test_basics():
+    """
+    Test mean/sd/var functionality of Poisson distribution.
+    """
+    net = PoissonBasics()
+    ans = net()
+    assert isinstance(ans, Tensor)
+
+class PoissonConstruct(nn.Cell):
+    """
+    Poisson distribution: going through construct.
+    """
+    def __init__(self):
+        super(PoissonConstruct, self).__init__()
+        self.p = msd.Poisson([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32)
+        self.p1 = msd.Poisson(dtype=dtype.float32)
+
+    def construct(self, value, rate):
+        prob = self.p('prob', value)
+        prob1 = self.p('prob', value, rate)
+        prob2 = self.p1('prob', value, rate)
+        return prob + prob1 + prob2
+
+def test_poisson_construct():
+    """
+    Test probability function going through construct.
+    """
+    net = PoissonConstruct()
+    value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
+    rate = Tensor([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32)
+    ans = net(value, rate)
+    assert isinstance(ans, Tensor)