From 0aa26c181506d220433b137fe4538ce2271beb68 Mon Sep 17 00:00:00 2001 From: Xun Deng Date: Fri, 26 Jun 2020 15:48:15 -0400 Subject: [PATCH 1/2] add high level abstract class Distribution and two example class: Bernoulli and Normal --- mindspore/nn/__init__.py | 6 +- mindspore/nn/distribution/__init__.py | 27 ++ mindspore/nn/distribution/_utils/__init__.py | 24 ++ mindspore/nn/distribution/_utils/utils.py | 190 +++++++++++++ mindspore/nn/distribution/bernoulli.py | 126 +++++++++ mindspore/nn/distribution/distribution.py | 232 +++++++++++++++ mindspore/nn/distribution/normal.py | 124 ++++++++ .../test_distribution/test_bernoulli.py | 128 +++++++++ .../ascend/test_distribution/test_normal.py | 130 +++++++++ tests/ut/python/nn/test_distribution.py | 266 ++++++++++++++++++ 10 files changed, 1252 insertions(+), 1 deletion(-) create mode 100644 mindspore/nn/distribution/__init__.py create mode 100644 mindspore/nn/distribution/_utils/__init__.py create mode 100644 mindspore/nn/distribution/_utils/utils.py create mode 100644 mindspore/nn/distribution/bernoulli.py create mode 100644 mindspore/nn/distribution/distribution.py create mode 100644 mindspore/nn/distribution/normal.py create mode 100644 tests/st/ops/ascend/test_distribution/test_bernoulli.py create mode 100644 tests/st/ops/ascend/test_distribution/test_normal.py create mode 100644 tests/ut/python/nn/test_distribution.py diff --git a/mindspore/nn/__init__.py b/mindspore/nn/__init__.py index 8d5e7d3b0a..e5c133a9a6 100644 --- a/mindspore/nn/__init__.py +++ b/mindspore/nn/__init__.py @@ -17,13 +17,15 @@ Neural Networks Cells. Pre-defined building blocks or computing units to construct Neural Networks. """ -from . import layer, loss, optim, metrics, wrap +from . import layer, loss, optim, metrics, wrap, distribution from .cell import Cell, GraphKernel from .layer import * from .loss import * from .optim import * from .metrics import * from .wrap import * +from .distribution import * + __all__ = ["Cell", "GraphKernel"] __all__.extend(layer.__all__) @@ -31,5 +33,7 @@ __all__.extend(loss.__all__) __all__.extend(optim.__all__) __all__.extend(metrics.__all__) __all__.extend(wrap.__all__) +__all__.extend(distribution.__all__) + __all__.sort() diff --git a/mindspore/nn/distribution/__init__.py b/mindspore/nn/distribution/__init__.py new file mode 100644 index 0000000000..55b4b03ef7 --- /dev/null +++ b/mindspore/nn/distribution/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Distribution. + +The high-level components(Distributions) used to construct the probabilistic network. 
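+
+A short usage sketch (illustrative; these classes are defined in this package
+and exported through ``mindspore.nn``):
+
+    >>> import mindspore.nn as nn
+    >>> from mindspore import dtype
+    >>> n = nn.Normal(3.0, 4.0, dtype=dtype.float32)   # Normal(mean, sd)
+    >>> b = nn.Bernoulli(0.5, dtype=dtype.int32)       # Bernoulli(probs)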
+""" + +from .distribution import Distribution +from .normal import Normal +from .bernoulli import Bernoulli + +__all__ = ['Distribution', + 'Normal', + 'Bernoulli',] diff --git a/mindspore/nn/distribution/_utils/__init__.py b/mindspore/nn/distribution/_utils/__init__.py new file mode 100644 index 0000000000..816485643a --- /dev/null +++ b/mindspore/nn/distribution/_utils/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Distribution operation utility functions. +""" +from .utils import * + +__all__ = ['check_scalar', 'convert_to_batch', 'cast_to_tensor', + 'calc_batch_size', 'check_greater', + 'check_greater_equal_zero', + 'calc_broadcast_shape_from_param', + 'check_scalar_from_param', 'check_prob'] diff --git a/mindspore/nn/distribution/_utils/utils.py b/mindspore/nn/distribution/_utils/utils.py new file mode 100644 index 0000000000..0cb9c3cc68 --- /dev/null +++ b/mindspore/nn/distribution/_utils/utils.py @@ -0,0 +1,190 @@ + +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utitly functions to help distribution class.""" +import numpy as np +from mindspore.ops import operations as P +from mindspore.ops import _utils as utils +from ....common.tensor import Tensor +from ....common import dtype as mstype + + +def check_scalar(value): + """ + Check if input value is a scalar. + """ + return np.isscalar(value) + + +def cast_to_tensor(t, dtype=mstype.float32): + """ + Cast an user input value into a Tensor of dtype. + + Args: + t (int/float/list/numpy.ndarray/Tensor). + dtype (mindspore.dtype). + + Raises: + RuntimeError: if t cannot be cast to Tensor. + + Outputs: + Tensor. + """ + if isinstance(t, Tensor): + #check if the Tensor in shape of Tensor(4) + if t.dim() == 0: + value = t.asnumpy() + return Tensor([t], dtype=dtype) + #convert the type of tensor to dtype + t.set_dtype(dtype) + return t + if isinstance(t, (list, np.ndarray)): + return Tensor(t, dtype=dtype) + if check_scalar(t): + return Tensor([t], dtype=dtype) + raise RuntimeError("Input type is not supported.") + +def calc_batch_size(batch_shape): + """ + Calculate the size of a given batch_shape. + + Args: + batch_shape (tuple) + + Outputs: + int. 
+ """ + return int(np.prod(batch_shape)) + +def convert_to_batch(t, batch_shape, dtype): + """ + Convert a Tensor to a given batch shape. + + Args: + t (Tensor) + batch_shape (tuple) + dtype (mindspore.dtype) + Raises: + RuntimeError: if the converison cannot be done. + + Outputs: + Tensor, with shape of batch_shape. + """ + t = cast_to_tensor(t, dtype) + reshape = P.Reshape() + if t.shape != batch_shape: + mul = calc_batch_size(batch_shape) // t.size() + if (calc_batch_size(batch_shape) % t.size()) != 0: + raise RuntimeError("Cannot cast the tensor to the given batch shape.") + temp = list(t.asnumpy()) * mul + return reshape(Tensor(temp), batch_shape) + return t + +def check_scalar_from_param(params): + """ + Check if params are all scalars. + + Args: + params (dict): parameters used to initialized distribution. + + Notes: String parameters are excluded. + """ + for value in params.values(): + if isinstance(value, (str, type(params['dtype']))): + continue + elif check_scalar(value): + continue + else: + return False + return True + + +def calc_broadcast_shape_from_param(params): + """ + Calculate the broadcast shape from params. + + Args: + params (dict): parameters used to initialized distribution. + + Outputs: + tuple. + """ + broadcast_shape = [] + for value in params.values(): + if isinstance(value, (str, type(params['dtype']))): + continue + if value is None: + return None + value_t = cast_to_tensor(value, params['dtype']) + broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name']) + return tuple(broadcast_shape) + +def check_greater_equal_zero(value, name): + """ + Check if the given Tensor is greater zero. + + Args: + value (Tensor) + name (str) : name of the value. + + Raises: + ValueError: if the input value is less than zero. + + """ + less = P.Less() + zeros = Tensor([0.0], dtype=value.dtype) + value = less(value, zeros) + if value.asnumpy().any(): + raise ValueError('{} should be greater than zero.'.format(name)) + +def check_greater(a, b, name_a, name_b): + """ + Check if Tensor b is strictly greater than Tensor a. + + Args: + a (Tensor) + b (Tensor) + name_a (str): name of Tensor_a. + name_b (str): name of Tensor_b. + + Raises: + ValueError: if b is less than or equal to a + """ + less = P.Less() + value = less(a, b) + if not value.asnumpy().all(): + raise ValueError('{} should be less than {}'.format(name_a, name_b)) + + +def check_prob(p): + """ + Check if p is a proper probability, i.e. 0 <= p <=1. + + Args: + p (Tensor): value to check. + + Raises: + ValueError: if p is not a proper probability. + """ + less = P.Less() + greater = P.Greater() + zeros = Tensor([0.0], dtype=p.dtype) + ones = Tensor([1.0], dtype=p.dtype) + comp = less(p, zeros) + if comp.asnumpy().any(): + raise ValueError('Probabilities should be greater than or equal to zero') + comp = greater(p, ones) + if comp.asnumpy().any(): + raise ValueError('Probabilities should be less than or equal to one') diff --git a/mindspore/nn/distribution/bernoulli.py b/mindspore/nn/distribution/bernoulli.py new file mode 100644 index 0000000000..04ecb5a37e --- /dev/null +++ b/mindspore/nn/distribution/bernoulli.py @@ -0,0 +1,126 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bernoulli Distribution""" +from mindspore.ops import operations as P +from .distribution import Distribution +from ._utils.utils import cast_to_tensor, check_prob +from ...common import dtype as mstype + +class Bernoulli(Distribution): + """ + Example class: Bernoulli Distribution. + + Args: + probs (int/float/list/numpy.ndarray/Tensor): probability of 1 as outcome. + dtype (mindspore.dtype): type of the distribution, default to int32. + + Note: + probs should be proper probabilities (0 <= p <= 1). + + Examples: + >>> # To initialize a Bernoulli distribution which has equal probability of getting 1 and 0 + >>> b = nn.Bernoulli(0.5, dtype = dtype.int32) + >>> # The following create two independent Bernoulli distributions + >>> b = nn.Bernoulli([0.7, 0.2], dtype = dtype.int32) + """ + + def __init__(self, + probs=None, + dtype=mstype.int32, + name="Bernoulli"): + """ + Constructor of Bernoulli distribution. + """ + param = dict(locals()) + super(Bernoulli, self).__init__(dtype, name, param) + if probs is not None: + self._probs = cast_to_tensor(probs) + # check if the input probability is valid + check_prob(self._probs) + else: + self._probs = probs + + # ops needed for the class + self.log = P.Log() + self.add = P.TensorAdd() + self.mul = P.Mul() + self.sqrt = P.Sqrt() + self.realdiv = P.RealDiv() + + + def probs(self): + """ + Returns the probability for the outcome is 1. + """ + return self._probs + + def _mean(self): + r""" + .. math:: + MEAN(B) = probs1 + """ + + return self._probs + + def _var(self): + r""" + .. math:: + VAR(B) = probs1 * probs0 + """ + probs0 = self.add(1, -1 * self._probs) + return self.mul(probs0, self._probs) + + def _prob(self, name, value, probs=None): + r""" + pmf of Bernoulli distribution. + + Args: + name (str): name of the function. Should be "prob" when passed in from construct. + value (Tensor): a Tensor composed of only zeros and ones. + probs (Tensor): probability of outcome is 1. Default to self._probs. + + .. math:: + pmf(k) = probs1 if k = 1; + pmf(k) = probs0 if k = 0; + """ + probs1 = self._probs if probs is None else probs + probs0 = self.add(1, -1 * probs1) + return self.add(self.mul(probs1, value), + self.mul(probs0, self.add(1, -1 * value))) + + def _kl_loss(self, name, dist, probs1_b): + r""" + Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). + + Args: + name (str): name of the funtion. Should always be "kl_loss" when passed in from construct. + dist (str): type of the distributions. Should be "Bernoulli" in this case. + probs1_b (Tensor): probs1 of distribution b. + + .. 
math::
+            KL(a||b) = probs1_a * \log(\frac{probs1_a}{probs1_b}) +
+                probs0_a * \log(\frac{probs0_a}{probs0_b})
+        """
+        if dist == 'Bernoulli':
+            probs1_a = self._probs
+            probs0_a = self.add(1, -1 * probs1_a)
+            probs0_b = self.add(1, -1 * probs1_b)
+            return self.add(probs1_a * self.log(self.realdiv(probs1_a, probs1_b)),
+                            probs0_a * self.log(self.realdiv(probs0_a, probs0_b)))
+        return None
+
+    def extend_repr(self):
+        str_info = 'probs={}'.format(self._probs)
+        return str_info
diff --git a/mindspore/nn/distribution/distribution.py b/mindspore/nn/distribution/distribution.py
new file mode 100644
index 0000000000..dcf34037dc
--- /dev/null
+++ b/mindspore/nn/distribution/distribution.py
@@ -0,0 +1,232 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""basic"""
+from ..cell import Cell
+from ._utils.utils import calc_broadcast_shape_from_param
+
+
+class Distribution(Cell):
+    """
+    Base class for all mathematical distributions.
+
+    Note:
+        Derived classes should override operations such as _mean, _prob,
+        and _log_prob. Functions should be called through construct when
+        used inside a network, in the form of the function name followed
+        by its arguments.
+
+    Examples:
+        >>> class MyNormalDistribution(Distribution):
+        >>>     def __init__(self):
+        >>>         super(MyNormalDistribution, self).__init__()
+        >>>         self._mean_value = Tensor([2.0, 3.0])
+        >>>         self._sd_value = Tensor([2.0, 3.0])
+        >>>
+        >>>     def _mean(self):
+        >>>         return self._mean_value
+
+    """
+    def __init__(self,
+                 dtype,
+                 name,
+                 param):
+
+        """
+        Constructor of distribution class.
+        """
+        super(Distribution, self).__init__()
+        self._name = name
+        self._dtype = dtype
+        self._parameters = {}
+        # parsing parameters
+        for k in param.keys():
+            if not(k == 'self' or k.startswith('_')):
+                self._parameters[k] = param[k]
+        # some attributes
+        self._broadcast_shape = calc_broadcast_shape_from_param(
+            self._parameters)
+
+        # set the function to call according to the derived class's attributes
+        self._set_prob()
+        self._set_log_prob()
+        self._set_sd()
+
+    def _set_prob(self):
+        """
+        Set the probability function based on the availability of _prob and _log_likelihood.
+        """
+        if hasattr(self, '_prob'):
+            self._call_prob = self._prob
+        elif hasattr(self, '_log_likelihood'):
+            self._call_prob = self._calc_prob_from_log_likelihood
+
+    def _set_sd(self):
+        """
+        Set the standard deviation based on the availability of _sd and _var.
+        """
+        if hasattr(self, '_sd'):
+            self._call_sd = self._sd
+        elif hasattr(self, '_var'):
+            self._call_sd = self._calc_sd_from_var
+
+    def _set_log_prob(self):
+        """
+        Set the log probability based on the availability of _prob and _log_likelihood.
+        """
+        if hasattr(self, '_log_likelihood'):
+            self._call_log_prob = self._log_likelihood
+        elif hasattr(self, '_prob'):
+            self._call_log_prob = self._calc_log_prob_from_prob
+
+    def log_likelihood(self, *args):
+        """
+        Evaluate the log probability at the given value.
+
+        Note:
+            value is cast to Tensor for further calculation.
+
+        Args:
+            name (str): name of the calling function.
+            value (Tensor): values to be evaluated.
+            mean (Tensor): mean of the distribution. Default: self.mean.
+            sd (Tensor): standard deviation of the distribution. Default: self.sd.
+
+        Outputs:
+            Tensor, shape: broadcast_shape of the distribution.
+        """
+        return self._call_log_prob(*args)
+
+    def _calc_prob_from_log_likelihood(self, *args):
+        r"""
+        Evaluate prob from log probability.
+
+        .. math::
+            probability(x) = \exp(log\_likelihood(x))
+
+        Args:
+            name (str): name of the calling function.
+            value (Tensor): values to be evaluated.
+            mean (Tensor): mean of the distribution. Default: self.mean.
+            sd (Tensor): standard deviation of the distribution. Default: self.sd.
+        """
+        return self.exp(self._log_likelihood(*args))
+
+    def _call_prob(self, *args):
+        """
+        Raises:
+            NotImplementedError: when the derived class did not override _prob or _log_likelihood.
+        """
+        raise NotImplementedError('pdf/pmf is not implemented: {}'.format(type(self).__name__))
+
+    def _call_log_prob(self, *args):
+        """
+        Raises:
+            NotImplementedError: when the derived class did not override _prob or _log_likelihood.
+        """
+        raise NotImplementedError('log_probability is not implemented: {}'.format(type(self).__name__))
+
+    def _call_sd(self):
+        """
+        Raises:
+            NotImplementedError: when the derived class did not override _sd or _var.
+        """
+        raise NotImplementedError('standard deviation is not implemented: {}'.format(type(self).__name__))
+
+    def prob(self, *args):
+        """
+        Evaluate the prob (pdf or pmf) at the given value.
+
+        Note:
+            value is cast to Tensor for further calculation.
+
+        Args:
+            name (str): name of the calling function.
+            value (Tensor): values to be evaluated.
+            mean (Tensor): mean of the distribution.
+            sd (Tensor): standard deviation of the distribution.
+
+        Outputs:
+            Tensor, shape: broadcast_shape of the distribution.
+        """
+        return self._call_prob(*args)
+
+    def _calc_log_prob_from_prob(self, *args):
+        r"""
+        Evaluate log probability from probability.
+
+        .. math::
+            log\_prob(x) = \log(prob(x))
+        """
+        return self.log(self._prob(*args))
+
+    def kl_loss(self, **kwargs):
+        """
+        Evaluate the KL divergence. Parameters of the second distribution should be
+        passed in through **kwargs.
+
+        Outputs:
+            Tensor, shape: broadcast_shape of the distribution and input distribution.
+        """
+        return self._kl_loss(**kwargs)
+
+    def mean(self, **kwargs):
+        """
+        Evaluate the mean.
+
+        Outputs:
+            Tensor, shape: broadcast_shape of the distribution.
+        """
+        return self._mean(**kwargs)
+
+    def sd(self, **kwargs):
+        """
+        Evaluate the standard deviation.
+
+        Outputs:
+            Tensor, with shape of broadcast_shape of the distribution.
+        """
+        return self._call_sd(**kwargs)
+
+    def _calc_sd_from_var(self, **kwargs):
+        r"""
+        Evaluate the standard deviation from the variance.
+
+        .. math::
+            STD(x) = \sqrt{VAR(x)}
+        """
+        return self.sqrt(self._var(**kwargs))
+
+    def construct(self, *inputs):
+        """
+        Override construct in Cell.
+
+        Args:
+            *inputs: inputs[0] is always the name of the function.
+
+        Notes:
+            The function name in inputs[0] is dispatched to the corresponding
+            private implementation; unrecognized names return None.
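+
+        Example (illustrative; ``d`` is an instance of a derived
+        distribution such as Bernoulli):
+            >>> d('prob', value)    # dispatched to the derived _prob
+            >>> d('mean')           # dispatched to the derived _mean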
+ """ + + if inputs[0] == 'log_prob': + return self._call_log_prob(*inputs) + if inputs[0] == 'prob': + return self._call_prob(*inputs) + if inputs[0] == 'kl_loss': + return self._kl_loss(*inputs) + if inputs[0] == 'mean': + return self._mean() + if inputs[0] == 'sd': + return self._call_sd() + return None diff --git a/mindspore/nn/distribution/normal.py b/mindspore/nn/distribution/normal.py new file mode 100644 index 0000000000..be3e359a9e --- /dev/null +++ b/mindspore/nn/distribution/normal.py @@ -0,0 +1,124 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Normal Distribution""" +import numpy as np +from mindspore.ops import operations as P +from .distribution import Distribution +from ._utils.utils import convert_to_batch, check_greater_equal_zero +from ...common import dtype as mstype +from ...context import get_context + +class Normal(Distribution): + """ + Example class: Normal distribution. + + Args: + mean (int/float/list/numpy.ndarray/Tensor): mean of the Gaussian distribution + standard deviation (int/float/list/numpy.ndarray/Tensor): vairance of the Gaussian distribution + dtype (mindspore.dtype): type of the distribution + + Note: + Standard deviation should be greater than zero. + + Examples: + >>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0 + >>> n = nn.Normal(3.0, 4.0, dtype=dtype.float32) + >>> # The following create two independent normal distributions + >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=dtype.float32) + """ + + def __init__(self, + mean=None, + sd=None, + dtype=mstype.float32, + name="Normal"): + """ + Constructor of normal distribution. + """ + param = dict(locals()) + super(Normal, self).__init__(dtype, name, param) + if mean is not None and sd is not None: + self._mean_value = convert_to_batch(mean, self._broadcast_shape, dtype) + self._sd_value = convert_to_batch(sd, self._broadcast_shape, dtype) + #check validity of standard deviation + check_greater_equal_zero(self._sd_value, "Standard deviation") + else: + self._mean_value = mean + self._sd_value = sd + + #ops needed for the class + self.exp = P.Exp() + self.add = P.TensorAdd() + self.sq = P.Square() + self.log = P.Log() + self.sqrt = P.Sqrt() + self.realdiv = P.RealDiv() + self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step + + def _expm1_by_step(self, x): + """ + Expm1 ops under GPU context. + """ + return self.add(self.exp(x), -1) + + def _mean(self): + """ + Mean of the distribution. + """ + return self._mean_value + + def _sd(self): + """ + Standard deviation of the distribution. + """ + return self._sd_value + + def _log_likelihood(self, name, value, mean=None, sd=None): + r""" + Evaluate log probability. + + .. math:: + L(x) = -1* \fract{(x - \mu)^2}{2. 
* \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) + """ + mean = self._mean_value if mean is None else mean + sd = self._sd_value if sd is None else sd + unnormalized_log_prob = -1. * self.realdiv(self.sq(self.add(value, -1. * mean)), + 2. * self.sq(sd)) + neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) + return self.add(unnormalized_log_prob, neg_normalization) + + def _kl_loss(self, name, dist, mean, sd): + r""" + Evaluate Normal-Normal kl divergence, i.e. KL(a||b). + + Args: + name (str): name of the funtion passed in from construct. Should always be "kl_loss". + dist (str): type of the distributions. Should be "Normal" in this case. + mean (Tensor): mean of distribution b. + sd (Tensor): standard deviation distribution b. + + .. math:: + KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 + + 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) + """ + if dist == 'Normal': + diff_log_scale = self.add(self.log(self._sd_value), - self.log(sd)) + squared_diff = self.sq(self.add(self.realdiv(self._mean_value, sd), - self.realdiv(mean, sd))) + return self.add(self.add(0.5 * squared_diff, 0.5 * self.expm1(2 * diff_log_scale)), - diff_log_scale) + return None + + def extend_repr(self): + str_info = 'mean={}, standard deviation={}'.format(self._mean_value, self._sd_value) + return str_info diff --git a/tests/st/ops/ascend/test_distribution/test_bernoulli.py b/tests/st/ops/ascend/test_distribution/test_bernoulli.py new file mode 100644 index 0000000000..1137260512 --- /dev/null +++ b/tests/st/ops/ascend/test_distribution/test_bernoulli.py @@ -0,0 +1,128 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for bernoulli distribution""" +import numpy as np +from scipy import stats +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Net(nn.Cell): + """ + Test class: probability of bernoulli distribution. + """ + def __init__(self): + super(Net, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.b('prob', x_) + +class Net1(nn.Cell): + """ + Test class: log probability of bernoulli distribution. + """ + def __init__(self): + super(Net1, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.b('log_prob', x_) + +class Net2(nn.Cell): + """ + Test class: kl_loss between bernoulli distributions. + """ + def __init__(self): + super(Net2, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.b('kl_loss', 'Bernoulli', x_) + +class Net3(nn.Cell): + """ + Test class: mean/sd of bernoulli distribution. 
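+
+    For probs1 = [0.7, 0.5], the expected results follow the definitions
+    above: mean = probs1 and sd = sqrt(probs1 * (1 - probs1)).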
+ """ + def __init__(self): + super(Net3, self).__init__() + self.b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32) + + @ms_function + def construct(self): + return self.b('mean'), self.b('sd') + +def test_pmf(): + """ + Test pmf. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_pmf = bernoulli_benchmark.pmf([0, 1, 0, 1, 1]).astype(np.float32) + pdf = Net() + x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) + output = pdf(x_) + print("expected_pmf: ", expect_pmf) + print("ans: ", output.asnumpy()) + tol = 1e-6 + assert (output.asnumpy() - expect_pmf < tol).all() + +def test_log_likelihood(): + """ + Test log_pmf. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_logpmf = bernoulli_benchmark.logpmf([0, 1, 0, 1, 1]).astype(np.float32) + logprob = Net1() + x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) + output = logprob(x_) + print("expected_log_probability: ", expect_logpmf) + print("ans: ", output.asnumpy()) + tol = 1e-6 + assert (output.asnumpy() - expect_logpmf < tol).all() + +def test_kl_loss(): + """ + Test kl_loss. + """ + probs1_a = 0.7 + probs1_b = 0.5 + probs0_a = 1 - probs1_a + probs0_b = 1 - probs1_b + expect_kl_loss = probs1_a * np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b) + kl_loss = Net2() + output = kl_loss(Tensor([probs1_b], dtype=dtype.float32)) + print("expected_kl_loss: ", expect_kl_loss) + print("ans: ", output.asnumpy()) + tol = 1e-6 + assert (output.asnumpy() - expect_kl_loss < tol).all() + +def test_basics(): + """ + Test mean/standard deviation and probs. + """ + basics = Net3() + mean, sd = basics() + print("mean : ", mean) + print("sd : ", sd) + b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32) + probs = b.probs() + print("probs is ", probs) diff --git a/tests/st/ops/ascend/test_distribution/test_normal.py b/tests/st/ops/ascend/test_distribution/test_normal.py new file mode 100644 index 0000000000..9977f934ad --- /dev/null +++ b/tests/st/ops/ascend/test_distribution/test_normal.py @@ -0,0 +1,130 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for normal distribution""" +import numpy as np +from scipy import stats +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Net(nn.Cell): + """ + Test class: probability of normal distribution. + """ + def __init__(self): + super(Net, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.n('prob', x_) + +class Net1(nn.Cell): + """ + Test class: log probability of normal distribution. 
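+
+    Note: mean [3.0] and sd [[2.0], [4.0]] broadcast to batch shape (2, 1),
+    so evaluating a length-2 value yields a (2, 2) result.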
+ """ + def __init__(self): + super(Net1, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.n('log_prob', x_) + +class Net2(nn.Cell): + """ + Test class: kl_loss of normal distribution. + """ + def __init__(self): + super(Net2, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + + @ms_function + def construct(self, x_, y_): + return self.n('kl_loss', 'Normal', x_, y_) + +class Net3(nn.Cell): + """ + Test class: mean/sd of normal distribution. + """ + def __init__(self): + super(Net3, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + @ms_function + def construct(self): + return self.n('mean'), self.n('sd') + +def test_pdf(): + """ + Test pdf. + """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_pdf = norm_benchmark.pdf([1.0, 2.0]).astype(np.float32) + pdf = Net() + output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32)) + print("expected_pdf: ", expect_pdf) + print("ans: ", output.asnumpy()) + tol = 1e-6 + assert (output.asnumpy() - expect_pdf < tol).all() + +def test_log_likelihood(): + """ + Test log_pdf. + """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_logpdf = norm_benchmark.logpdf([1.0, 2.0]).astype(np.float32) + logprob = Net1() + output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32)) + print("expected_log_probability: ", expect_logpdf) + print("ans: ", output.asnumpy()) + tol = 1e-6 + assert (output.asnumpy() - expect_logpdf < tol).all() + +def test_kl_loss(): + """ + Test kl_loss. + """ + mean_a = np.array([3.0]).astype(np.float32) + sd_a = np.array([4.0]).astype(np.float32) + + mean_b = np.array([1.0]).astype(np.float32) + sd_b = np.array([1.0]).astype(np.float32) + + diff_log_scale = np.log(sd_a) - np.log(sd_b) + squared_diff = np.square(mean_a / sd_b - mean_b / sd_b) + expect_kl_loss = 0.5 * squared_diff + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale + + kl_loss = Net2() + mean = Tensor(mean_b, dtype=dtype.float32) + sd = Tensor(sd_b, dtype=dtype.float32) + output = kl_loss(mean, sd) + print("expected_kl_loss: ", expect_kl_loss) + print("ans: ", output.asnumpy()) + tol = 1e-6 + assert (output.asnumpy() - expect_kl_loss < tol).all() + +def test_basics(): + """ + Test mean/standard deviation. + """ + basics = Net3() + mean, sd = basics() + print("mean is ", mean) + print("sd is ", sd) diff --git a/tests/ut/python/nn/test_distribution.py b/tests/ut/python/nn/test_distribution.py new file mode 100644 index 0000000000..dbb6bf523c --- /dev/null +++ b/tests/ut/python/nn/test_distribution.py @@ -0,0 +1,266 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.Distribution. + +Including Normal Distribution and Bernoulli Distribution. 
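+
+These are interface smoke tests: they exercise construct-dispatch calls such
+as d('prob', value) and print results rather than asserting numerical values.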
+""" +import pytest +import numpy as np + +import mindspore.nn as nn +from mindspore import dtype +from mindspore import Tensor + +def test_normal_shape_errpr(): + """ + Invalid shapes. + """ + with pytest.raises(ValueError): + nn.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) + +def test_no_arguments(): + """ + No args passed in during initialization. + """ + n = nn.Normal() + b = nn.Bernoulli() + print(n) + print(b) + +def test_with_arguments(): + """ + Args passed in during initialization. + """ + n = nn.Normal([3.0], [4.0], dtype=dtype.float32) + b = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32) + print(n) + print(b) + +class NormalProb(nn.Cell): + """ + Normal distribution: initialize with mean/sd. + """ + def __init__(self): + super(NormalProb, self).__init__() + self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) + + def construct(self, value): + x = self.normal('prob', value) + y = self.normal('log_prob', value) + return x, y + +def test_normal_prob(): + """ + Test pdf/log_pdf: passing value through construct. + """ + net = NormalProb() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + pdf, log_pdf = net(value) + print("pdf: ", pdf) + print("log_pdf: ", log_pdf) + +class NormalProb1(nn.Cell): + """ + Normal distribution: initialize without mean/sd. + """ + def __init__(self): + super(NormalProb1, self).__init__() + self.normal = nn.Normal() + + def construct(self, value, mean, sd): + x = self.normal('prob', value, mean, sd) + y = self.normal('log_prob', value, mean, sd) + return x, y + +def test_normal_prob1(): + """ + Test pdf/logpdf: passing mean/sd, value through construct. + """ + net = NormalProb1() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + mean = Tensor([0.0], dtype=dtype.float32) + sd = Tensor([1.0], dtype=dtype.float32) + pdf, log_pdf = net(value, mean, sd) + print("pdf: ", pdf) + print("log_pdf: ", log_pdf) + + +class NormalProb2(nn.Cell): + """ + Normal distribution: initialize with mean/sd. + """ + def __init__(self): + super(NormalProb2, self).__init__() + self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) + + def construct(self, value, mean, sd): + x = self.normal('prob', value, mean, sd) + y = self.normal('log_prob', value, mean, sd) + return x, y + +def test_normal_prob2(): + """ + Test pdf/log_pdf: passing mean/sd through construct. + Overwrite original mean/sd. + """ + net = NormalProb2() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + mean = Tensor([0.0], dtype=dtype.float32) + sd = Tensor([1.0], dtype=dtype.float32) + pdf, log_pdf = net(value, mean, sd) + print("pdf: ", pdf) + print("log_pdf: ", log_pdf) + +class BernoulliProb(nn.Cell): + """ + Bernoulli distribution: initialize with probs. + """ + def __init__(self): + super(BernoulliProb, self).__init__() + self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32) + + def construct(self, value): + x = self.bernoulli('prob', value) + y = self.bernoulli('log_prob', value) + return x, y + +def test_bernoulli_prob(): + """ + Test pmf/log_pmf: passing value through construct. + """ + net = BernoulliProb() + value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) + ans = net(value) + print("pmf: ", ans) + print("log_pmf: ", ans) + + +class BernoulliProb1(nn.Cell): + """ + Bernoulli distribution: initialize without probs. 
+ """ + def __init__(self): + super(BernoulliProb1, self).__init__() + self.bernoulli = nn.Bernoulli() + + def construct(self, value, probs): + x = self.bernoulli('prob', value, probs) + y = self.bernoulli('log_prob', value, probs) + return x, y + +def test_bernoulli_prob1(): + """ + Test pmf/log_pmf: passing probs through construct. + """ + net = BernoulliProb1() + value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) + probs = Tensor([0.3], dtype=dtype.float32) + ans = net(value, probs) + print("pmf: ", ans) + print("log_pmf: ", ans) + + +class BernoulliProb2(nn.Cell): + """ + Bernoulli distribution: initialize with probs. + """ + def __init__(self): + super(BernoulliProb2, self).__init__() + self.bernoulli = nn.Bernoulli(0.5) + + def construct(self, value, probs): + x = self.bernoulli('prob', value, probs) + y = self.bernoulli('log_prob', value, probs) + return x, y + +def test_bernoulli_prob2(): + """ + Test pmf/log_pmf: passing probs/value through construct. + Overwrite original probs. + """ + net = BernoulliProb2() + value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) + probs = Tensor([0.3], dtype=dtype.float32) + ans = net(value, probs) + print("pmf: ", ans) + print("log_pmf: ", ans) + +class NormalKl(nn.Cell): + """ + Test class: kl_loss of Normal distribution. + """ + def __init__(self): + super(NormalKl, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + + def construct(self, x_, y_): + return self.n('kl_loss', 'Normal', x_, y_) + +class BernoulliKl(nn.Cell): + """ + Test class: kl_loss between Bernoulli distributions. + """ + def __init__(self): + super(BernoulliKl, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.b('kl_loss', 'Bernoulli', x_) + +def test_kl(): + """ + Test kl_loss function. + """ + nor_net = NormalKl() + mean_b = np.array([1.0]).astype(np.float32) + sd_b = np.array([1.0]).astype(np.float32) + mean = Tensor(mean_b, dtype=dtype.float32) + sd = Tensor(sd_b, dtype=dtype.float32) + output = nor_net(mean, sd) + print("normal-normal kl loss: ", output) + + ber_net = BernoulliKl() + probs_b = Tensor([0.3], dtype=dtype.float32) + output = ber_net(probs_b) + print("bernoulli-bernoulli kl loss: ", output) + + +class NormalBernoulli(nn.Cell): + """ + Test class: basic mean/sd function. + """ + def __init__(self): + super(NormalBernoulli, self).__init__() + self.n = nn.Normal(3.0, 4.0, dtype=dtype.int32) + self.b = nn.Bernoulli(0.5, dtype=dtype.int32) + + def construct(self): + normal_mean = self.n('mean') + normal_sd = self.n('sd') + bernoulli_mean = self.b('mean') + bernoulli_sd = self.b('sd') + return normal_mean, normal_sd, bernoulli_mean, bernoulli_sd + +def test_bascis(): + """ + Test mean/sd functionality of Normal and Bernoulli. 
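+
+    Expected values: Normal(3.0, 4.0) has mean 3.0 and sd 4.0; Bernoulli(0.5)
+    has mean 0.5 and sd sqrt(0.5 * (1 - 0.5)) = 0.5.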
+ """ + net = NormalBernoulli() + normal_mean, normal_sd, bernoulli_mean, bernoulli_sd = net() + print("Mean of Normal distribution: ", normal_mean) + print("Standard deviation of Normal distribution: ", normal_sd) + print("Mean of Bernoulli distribution: ", bernoulli_mean) + print("Standard deviation of Bernoulli distribution: ", bernoulli_sd) From bef1fc7f19fd8f26c37c5d512255e8c1aaf56556 Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Fri, 26 Jun 2020 17:56:13 -0300 Subject: [PATCH 2/2] add sample functions in normal and bermoulli distributions --- mindspore/nn/distribution/_utils/utils.py | 85 +++++---- mindspore/nn/distribution/bernoulli.py | 85 ++++++--- mindspore/nn/distribution/distribution.py | 74 ++------ mindspore/nn/distribution/normal.py | 97 +++++++--- .../test_distribution/test_bernoulli.py | 45 +++-- .../ascend/test_distribution/test_normal.py | 46 +++-- tests/ut/python/nn/test_distribution.py | 179 ++++++++++++++---- 7 files changed, 409 insertions(+), 202 deletions(-) diff --git a/mindspore/nn/distribution/_utils/utils.py b/mindspore/nn/distribution/_utils/utils.py index 0cb9c3cc68..108cff6614 100644 --- a/mindspore/nn/distribution/_utils/utils.py +++ b/mindspore/nn/distribution/_utils/utils.py @@ -15,9 +15,9 @@ # ============================================================================ """Utitly functions to help distribution class.""" import numpy as np -from mindspore.ops import operations as P from mindspore.ops import _utils as utils -from ....common.tensor import Tensor +from ....common.tensor import Tensor, MetaTensor +from ....common.parameter import Parameter from ....common import dtype as mstype @@ -33,15 +33,17 @@ def cast_to_tensor(t, dtype=mstype.float32): Cast an user input value into a Tensor of dtype. Args: - t (int/float/list/numpy.ndarray/Tensor). - dtype (mindspore.dtype). + t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor. + dtype (mindspore.dtype): dtype of the Tensor. Default: mstype.float32. Raises: RuntimeError: if t cannot be cast to Tensor. - Outputs: + Returns: Tensor. """ + if isinstance(t, Parameter): + return t if isinstance(t, Tensor): #check if the Tensor in shape of Tensor(4) if t.dim() == 0: @@ -61,9 +63,9 @@ def calc_batch_size(batch_shape): Calculate the size of a given batch_shape. Args: - batch_shape (tuple) + batch_shape (tuple): batch shape to be calculated. - Outputs: + Returns: int. """ return int(np.prod(batch_shape)) @@ -73,23 +75,26 @@ def convert_to_batch(t, batch_shape, dtype): Convert a Tensor to a given batch shape. Args: - t (Tensor) - batch_shape (tuple) - dtype (mindspore.dtype) + t (Tensor, Parameter): Tensor to be converted. + batch_shape (tuple): desired batch shape. + dtype (mindspore.dtype): desired dtype. + Raises: RuntimeError: if the converison cannot be done. - Outputs: + Returns: Tensor, with shape of batch_shape. """ + if isinstance(t, Parameter): + return t t = cast_to_tensor(t, dtype) - reshape = P.Reshape() if t.shape != batch_shape: mul = calc_batch_size(batch_shape) // t.size() if (calc_batch_size(batch_shape) % t.size()) != 0: raise RuntimeError("Cannot cast the tensor to the given batch shape.") temp = list(t.asnumpy()) * mul - return reshape(Tensor(temp), batch_shape) + temp = np.reshape(temp, batch_shape) + return Tensor(temp, dtype) return t def check_scalar_from_param(params): @@ -97,7 +102,7 @@ def check_scalar_from_param(params): Check if params are all scalars. Args: - params (dict): parameters used to initialized distribution. 
+ params (dict): parameters used to initialize distribution. Notes: String parameters are excluded. """ @@ -116,9 +121,9 @@ def calc_broadcast_shape_from_param(params): Calculate the broadcast shape from params. Args: - params (dict): parameters used to initialized distribution. + params (dict): parameters used to initialize distribution. - Outputs: + Returns: tuple. """ broadcast_shape = [] @@ -127,7 +132,10 @@ def calc_broadcast_shape_from_param(params): continue if value is None: return None - value_t = cast_to_tensor(value, params['dtype']) + if isinstance(value, Parameter): + value_t = value.default_input + else: + value_t = cast_to_tensor(value, params['dtype']) broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name']) return tuple(broadcast_shape) @@ -136,36 +144,37 @@ def check_greater_equal_zero(value, name): Check if the given Tensor is greater zero. Args: - value (Tensor) + value (Tensor, Parameter): value to be checked. name (str) : name of the value. Raises: ValueError: if the input value is less than zero. """ - less = P.Less() - zeros = Tensor([0.0], dtype=value.dtype) - value = less(value, zeros) - if value.asnumpy().any(): - raise ValueError('{} should be greater than zero.'.format(name)) + if isinstance(value, Parameter): + if isinstance(value.default_input, MetaTensor): + return + value = value.default_input + comp = np.less(value.asnumpy(), np.zeros(value.shape)) + if comp.any(): + raise ValueError(f'{name} should be greater than zero.') def check_greater(a, b, name_a, name_b): """ Check if Tensor b is strictly greater than Tensor a. Args: - a (Tensor) - b (Tensor) + a (Tensor): input tensor a. + b (Tensor): input tensor b. name_a (str): name of Tensor_a. name_b (str): name of Tensor_b. Raises: ValueError: if b is less than or equal to a """ - less = P.Less() - value = less(a, b) - if not value.asnumpy().all(): - raise ValueError('{} should be less than {}'.format(name_a, name_b)) + comp = np.less(a.asnumpy(), b.asnumpy()) + if not comp.all(): + raise ValueError(f'{name_a} should be less than {name_b}') def check_prob(p): @@ -173,18 +182,18 @@ def check_prob(p): Check if p is a proper probability, i.e. 0 <= p <=1. Args: - p (Tensor): value to check. + p (Tensor, Parameter): value to be checked. Raises: ValueError: if p is not a proper probability. """ - less = P.Less() - greater = P.Greater() - zeros = Tensor([0.0], dtype=p.dtype) - ones = Tensor([1.0], dtype=p.dtype) - comp = less(p, zeros) - if comp.asnumpy().any(): + if isinstance(p, Parameter): + if isinstance(p.default_input, MetaTensor): + return + p = p.default_input + comp = np.less(p.asnumpy(), np.zeros(p.shape)) + if comp.any(): raise ValueError('Probabilities should be greater than or equal to zero') - comp = greater(p, ones) - if comp.asnumpy().any(): + comp = np.greater(p.asnumpy(), np.ones(p.shape)) + if comp.any(): raise ValueError('Probabilities should be less than or equal to one') diff --git a/mindspore/nn/distribution/bernoulli.py b/mindspore/nn/distribution/bernoulli.py index 04ecb5a37e..d0d8a5b08a 100644 --- a/mindspore/nn/distribution/bernoulli.py +++ b/mindspore/nn/distribution/bernoulli.py @@ -23,21 +23,24 @@ class Bernoulli(Distribution): Example class: Bernoulli Distribution. Args: - probs (int/float/list/numpy.ndarray/Tensor): probability of 1 as outcome. - dtype (mindspore.dtype): type of the distribution, default to int32. + probs (int, float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome. + seed (int): seed to use in sampling. 
Default: 0. + dtype (mindspore.dtype): type of the distribution. Default: mstype.int32. + name (str): name of the distribution. Default: Bernoulli. Note: probs should be proper probabilities (0 <= p <= 1). Examples: >>> # To initialize a Bernoulli distribution which has equal probability of getting 1 and 0 - >>> b = nn.Bernoulli(0.5, dtype = dtype.int32) + >>> b = nn.Bernoulli(0.5, dtype = mstype.int32) >>> # The following create two independent Bernoulli distributions - >>> b = nn.Bernoulli([0.7, 0.2], dtype = dtype.int32) + >>> b = nn.Bernoulli([0.7, 0.2], dtype = mstype.int32) """ def __init__(self, probs=None, + seed=0, dtype=mstype.int32, name="Bernoulli"): """ @@ -47,7 +50,6 @@ class Bernoulli(Distribution): super(Bernoulli, self).__init__(dtype, name, param) if probs is not None: self._probs = cast_to_tensor(probs) - # check if the input probability is valid check_prob(self._probs) else: self._probs = probs @@ -58,7 +60,17 @@ class Bernoulli(Distribution): self.mul = P.Mul() self.sqrt = P.Sqrt() self.realdiv = P.RealDiv() + self.shape = P.Shape() + self.const = P.ScalarToArray() + self.less = P.Less() + self.cast = P.Cast() + self.normal = P.Normal(seed=seed) + self.erf = P.Erf() + self.sqrt = P.Sqrt() + def extend_repr(self): + str_info = f'probs = {self._probs}' + return str_info def probs(self): """ @@ -66,21 +78,25 @@ class Bernoulli(Distribution): """ return self._probs - def _mean(self): + def _mean(self, name='mean', probs1=None): r""" .. math:: MEAN(B) = probs1 """ + if name == 'mean': + return self._probs if probs1 is None else probs1 + return None - return self._probs - - def _var(self): + def _var(self, name='var', probs1=None): r""" .. math:: VAR(B) = probs1 * probs0 """ - probs0 = self.add(1, -1 * self._probs) - return self.mul(probs0, self._probs) + if name in ('sd', 'var'): + probs1 = self._probs if probs1 is None else probs1 + probs0 = self.add(1, -1 * probs1) + return self.mul(probs0, probs1) + return None def _prob(self, name, value, probs=None): r""" @@ -89,18 +105,20 @@ class Bernoulli(Distribution): Args: name (str): name of the function. Should be "prob" when passed in from construct. value (Tensor): a Tensor composed of only zeros and ones. - probs (Tensor): probability of outcome is 1. Default to self._probs. + probs (Tensor): probability of outcome is 1. Default: self._probs. .. math:: pmf(k) = probs1 if k = 1; pmf(k) = probs0 if k = 0; """ - probs1 = self._probs if probs is None else probs - probs0 = self.add(1, -1 * probs1) - return self.add(self.mul(probs1, value), - self.mul(probs0, self.add(1, -1 * value))) + if name in ('prob', 'log_prob'): + probs1 = self._probs if probs is None else probs + probs0 = self.add(1, -1 * probs1) + return self.add(self.mul(probs1, value), + self.mul(probs0, self.add(1, -1 * value))) + return None - def _kl_loss(self, name, dist, probs1_b): + def _kl_loss(self, name, dist, probs1_b, probs1_a=None): r""" Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). @@ -108,19 +126,42 @@ class Bernoulli(Distribution): name (str): name of the funtion. Should always be "kl_loss" when passed in from construct. dist (str): type of the distributions. Should be "Bernoulli" in this case. probs1_b (Tensor): probs1 of distribution b. + probs1_a (Tensor): probs1 of distribution a. Default: self._probs. .. 
math:: KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) + probs0_a * \log(\fract{probs0_a}{probs0_b}) """ - if dist == 'Bernoulli': - probs1_a = self._probs + if name == 'kl_loss' and dist == 'Bernoulli': + probs1_a = self._probs if probs1_a is None else probs1_a probs0_a = self.add(1, -1 * probs1_a) probs0_b = self.add(1, -1 * probs1_b) return self.add(probs1_a * self.log(self.realdiv(probs1_a, probs1_b)), probs0_a * self.log(self.realdiv(probs0_a, probs0_b))) return None - def extend_repr(self): - str_info = 'probs={}'.format(self._probs) - return str_info + def _sample(self, name, shape=(), probs=None): + """ + Sampling. + + Args: + name (str): name of the function. Should always be 'sample' when passed in from construct. + shape (tuple): shape of the sample. Default: (). + probs (Tensor): probs1 of the samples. Default: self._probs. + + Returns: + Tensor, shape is shape + batch_shape. + """ + if name == 'sample': + probs1 = self._probs if probs is None else probs + batch_shape = self.shape(probs1) + sample_shape = shape + batch_shape + mean_zero = self.const(0.0) + sd_one = self.const(1.0) + sqrt_two = self.sqrt(self.const(2.0)) + sample_norm = self.normal(sample_shape, mean_zero, sd_one) + sample_uniform = 0.5 * (1 + self.erf(self.realdiv(sample_norm, sqrt_two))) + sample = self.less(sample_uniform, probs1) + sample = self.cast(sample, self._dtype) + return sample + return None diff --git a/mindspore/nn/distribution/distribution.py b/mindspore/nn/distribution/distribution.py index dcf34037dc..1ed7906a9e 100644 --- a/mindspore/nn/distribution/distribution.py +++ b/mindspore/nn/distribution/distribution.py @@ -21,6 +21,11 @@ class Distribution(Cell): """ Base class for all mathematical distributions. + Args: + dtype (mindspore.dtype): type of the distribution. + name (str): name of the distribution. + param (dict): parameters used to initialize the distribution. + Note: Derived class should override operations such as ,_mean, _prob, and _log_prob. Functions should be called through construct when @@ -97,14 +102,8 @@ class Distribution(Cell): Note: value is casted to Tensor for further calculation. - Args: - name (str): name of the calling function. - value (Tensor): values to be evaluated. - mean (Tensor): mean of the distirbution. Default: self.mean. - sd (Tensor): standard deviation of the distribution. Default: self.sd. - - Outputs: - Tensor, shape: broadcast_shape of the distribution. + Returns: + Tensor, shape is the broadcast_shape of the distribution. """ return self._call_log_prob(*args) @@ -114,36 +113,9 @@ class Distribution(Cell): .. math:: probability(x) = \exp(log_likehood(x)) - - Args: - name (str): name of the calling function. - value (Tensor): values to be evaluated. - mean (Tensor): mean of the distribution. Default: self.mean. - sd (Tensor): standard deviation of the distritbuion. Default: self.sd. """ return self.exp(self._log_likelihood(*args)) - def _call_prob(self, *args): - """ - Raises: - NotImplementedError when derived class didn't override _prob or _log_likelihood. - """ - raise NotImplementedError('pdf/pmf is not implemented: {}'.format(type(self).__name__)) - - def _call_log_prob(self, *args): - """ - Raises: - NotImplementedError when derived class didn't override _prob or _log_likelihood. - """ - raise NotImplementedError('log_probability is not implemented: {}'.format(type(self).__name__)) - - def _call_sd(self): - """ - Raises: - NotImplementedError when derived class didn't override _sd or _var. 
- """ - raise NotImplementedError('standard deviation is not implemented: {}'.format(type(self).__name__)) - def prob(self, *args): """ Evaluate the prob (pdf or pmf) at given value. @@ -151,14 +123,8 @@ class Distribution(Cell): Note: value is casted to Tensor for further calculation. - Args: - name (str): name of the calling function. - value (Tensor): values to be evaluated. - mean (Tensor): mean of the distribution. - sd (Tensor): standard deviation of the distritbuion. - - Outputs: - Tensor, shape: broadcast_shape of the distribution. + Returns: + Tensor, shape is the broadcast_shape of the distribution. """ return self._call_prob(*args) @@ -176,8 +142,8 @@ class Distribution(Cell): Evaluate the KL divergence. Parameters of the second distribution should be passed in through **kwargs. - Outputs: - Tensor, shape: broadcast_shape of the distribution and input distribution. + Returns: + Tensor, shape is the broadcast_shape of the distribution and input distribution. """ return self._kl_loss(**kwargs) @@ -185,8 +151,8 @@ class Distribution(Cell): """ Evaluate the mean. - Outputs: - Tensor, shape: broadcast_shape of the distribution. + Returns: + Tensor, shape is the broadcast_shape of the distribution. """ return self._mean(**kwargs) @@ -194,19 +160,19 @@ class Distribution(Cell): """ Evaluate the standard deviation. - Outputs: - Tensor, with shape of broadcast_shape of the distribution. + Returns: + Tensor, shape is the broadcast_shape of the distribution. """ return self._call_sd(**kwargs) - def _calc_sd_from_var(self, **kwargs): + def _calc_sd_from_var(self, *args): r""" Evaluate log probability from probability. .. math:: STD(x) = \sqrt(VAR(x)) """ - return self.sqrt(self._var(**kwargs)) + return self.sqrt(self._var(*args)) def construct(self, *inputs): """ @@ -226,7 +192,9 @@ class Distribution(Cell): if inputs[0] == 'kl_loss': return self._kl_loss(*inputs) if inputs[0] == 'mean': - return self._mean() + return self._mean(*inputs) if inputs[0] == 'sd': - return self._call_sd() + return self._call_sd(*inputs) + if inputs[0] == 'sample': + return self._sample(*inputs) return None diff --git a/mindspore/nn/distribution/normal.py b/mindspore/nn/distribution/normal.py index be3e359a9e..344dbd2eeb 100644 --- a/mindspore/nn/distribution/normal.py +++ b/mindspore/nn/distribution/normal.py @@ -25,23 +25,27 @@ class Normal(Distribution): Example class: Normal distribution. Args: - mean (int/float/list/numpy.ndarray/Tensor): mean of the Gaussian distribution - standard deviation (int/float/list/numpy.ndarray/Tensor): vairance of the Gaussian distribution - dtype (mindspore.dtype): type of the distribution + mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Gaussian distribution. + sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Gaussian distribution. + seed (int): seed to use in sampling. Default: 0. + dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. + name (str): name of the distribution. Default: Normal. + Note: Standard deviation should be greater than zero. 
Examples: >>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0 - >>> n = nn.Normal(3.0, 4.0, dtype=dtype.float32) + >>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32) >>> # The following create two independent normal distributions - >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=dtype.float32) + >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) """ def __init__(self, mean=None, sd=None, + seed=0, dtype=mstype.float32, name="Normal"): """ @@ -52,7 +56,6 @@ class Normal(Distribution): if mean is not None and sd is not None: self._mean_value = convert_to_batch(mean, self._broadcast_shape, dtype) self._sd_value = convert_to_batch(sd, self._broadcast_shape, dtype) - #check validity of standard deviation check_greater_equal_zero(self._sd_value, "Standard deviation") else: self._mean_value = mean @@ -61,11 +64,20 @@ class Normal(Distribution): #ops needed for the class self.exp = P.Exp() self.add = P.TensorAdd() + self.mul = P.Mul() self.sq = P.Square() self.log = P.Log() self.sqrt = P.Sqrt() self.realdiv = P.RealDiv() self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step + self.normal = P.Normal(seed=seed) + self.shape = P.Shape() + self.zeroslike = P.ZerosLike() + self.const = P.ScalarToArray() + + def extend_repr(self): + str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' + return str_info def _expm1_by_step(self, x): """ @@ -73,17 +85,23 @@ class Normal(Distribution): """ return self.add(self.exp(x), -1) - def _mean(self): + def _mean(self, name='mean', mean=None, sd=None): """ Mean of the distribution. """ - return self._mean_value + if name == 'mean': + mean = self._mean_value if mean is None or sd is None else mean + return mean + return None - def _sd(self): + def _sd(self, name='sd', mean=None, sd=None): """ Standard deviation of the distribution. """ - return self._sd_value + if name in ('sd', 'var'): + sd = self._sd_value if mean is None or sd is None else sd + return sd + return None def _log_likelihood(self, name, value, mean=None, sd=None): r""" @@ -92,33 +110,60 @@ class Normal(Distribution): .. math:: L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) """ - mean = self._mean_value if mean is None else mean - sd = self._sd_value if sd is None else sd - unnormalized_log_prob = -1. * self.realdiv(self.sq(self.add(value, -1. * mean)), - 2. * self.sq(sd)) - neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) - return self.add(unnormalized_log_prob, neg_normalization) - - def _kl_loss(self, name, dist, mean, sd): + if name in ('prob', 'log_prob'): + mean = self._mean_value if mean is None else mean + sd = self._sd_value if sd is None else sd + unnormalized_log_prob = -1. * self.realdiv(self.sq(self.add(value, -1. * mean)), + 2. * self.sq(sd)) + neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) + return self.add(unnormalized_log_prob, neg_normalization) + return None + + def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None): r""" Evaluate Normal-Normal kl divergence, i.e. KL(a||b). Args: name (str): name of the funtion passed in from construct. Should always be "kl_loss". dist (str): type of the distributions. Should be "Normal" in this case. - mean (Tensor): mean of distribution b. - sd (Tensor): standard deviation distribution b. + mean_b (Tensor): mean of distribution b. + sd_b (Tensor): standard deviation distribution b. + mean_a (Tensor): mean of distribution a. Default: self._mean_value. 
+            sd_a (Tensor): standard deviation of distribution a. Default: self._sd_value.
 
         .. math::
-            KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 +
-                       0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
+            KL(a||b) = 0.5 * (\frac{MEAN(a)}{STD(b)} - \frac{MEAN(b)}{STD(b)})^2 +
+                       0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b)))) - (\log(STD(a)) - \log(STD(b)))
         """
-        if dist == 'Normal':
-            diff_log_scale = self.add(self.log(self._sd_value), - self.log(sd))
-            squared_diff = self.sq(self.add(self.realdiv(self._mean_value, sd),
-                                            - self.realdiv(mean, sd)))
+        if name == 'kl_loss' and dist == 'Normal':
+            mean_a = self._mean_value if mean_a is None else mean_a
+            sd_a = self._sd_value if sd_a is None else sd_a
+            diff_log_scale = self.add(self.log(sd_a), - self.log(sd_b))
+            squared_diff = self.sq(self.add(self.realdiv(mean_a, sd_b), - self.realdiv(mean_b, sd_b)))
             return self.add(self.add(0.5 * squared_diff, 0.5 * self.expm1(2 * diff_log_scale)),
                             - diff_log_scale)
         return None
 
-    def extend_repr(self):
-        str_info = 'mean={}, standard deviation={}'.format(self._mean_value, self._sd_value)
-        return str_info
+    def _sample(self, name, shape=(), mean=None, sd=None):
+        """
+        Sampling.
+
+        Args:
+            name (str): name of the function. Should always be 'sample' when passed in from construct.
+            shape (tuple): shape of the sample. Default: ().
+            mean (Tensor): mean of the samples. Default: self._mean_value.
+            sd (Tensor): standard deviation of the samples. Default: self._sd_value.
+
+        Returns:
+            Tensor, shape is shape + batch_shape.
+        """
+        if name == 'sample':
+            mean = self._mean_value if mean is None else mean
+            sd = self._sd_value if sd is None else sd
+            batch_shape = self.shape(self.add(self.zeroslike(mean), self.zeroslike(sd)))
+            sample_shape = shape + batch_shape
+            mean_zero = self.const(0.0)
+            sd_one = self.const(1.0)
+            sample_norm = self.normal(sample_shape, mean_zero, sd_one)
+            sample = self.add(mean, self.mul(sample_norm, sd))
+            return sample
+        return None
diff --git a/tests/st/ops/ascend/test_distribution/test_bernoulli.py b/tests/st/ops/ascend/test_distribution/test_bernoulli.py
index 1137260512..5652d536c7 100644
--- a/tests/st/ops/ascend/test_distribution/test_bernoulli.py
+++ b/tests/st/ops/ascend/test_distribution/test_bernoulli.py
@@ -65,12 +65,25 @@ class Net3(nn.Cell):
     """
     def __init__(self):
         super(Net3, self).__init__()
-        self.b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32)
+        self.b = nn.Bernoulli([0.5, 0.5], dtype=dtype.int32)
 
     @ms_function
     def construct(self):
         return self.b('mean'), self.b('sd')
 
+class Net4(nn.Cell):
+    """
+    Test class: sample of Bernoulli distribution.
+    """
+    def __init__(self, shape, seed=0):
+        super(Net4, self).__init__()
+        self.b = nn.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
+        self.shape = shape
+
+    @ms_function
+    def construct(self, probs=None):
+        return self.b('sample', self.shape, probs)
+
 def test_pmf():
     """
     Test pmf.
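
A note on the sampling path added to normal.py above: `_sample` is the standard reparameterization identity x = mean + sd * z with z ~ N(0, 1), which is why the returned shape is `shape + batch_shape`. A minimal NumPy sketch of the same computation, for reference only (the helper name `sample_normal` is illustrative and not part of this patch):

    import numpy as np

    def sample_normal(shape, mean, sd, seed=0):
        """Reparameterized draw: x = mean + sd * z, z ~ N(0, 1)."""
        mean, sd = np.asarray(mean), np.asarray(sd)
        # batch_shape is the broadcast of mean and sd, mirroring the
        # zeroslike(mean) + zeroslike(sd) trick in Normal._sample above.
        batch_shape = np.broadcast(mean, sd).shape
        z = np.random.default_rng(seed).standard_normal(shape + batch_shape)
        return mean + sd * z    # result shape: shape + batch_shape

    # e.g. sample_normal((2, 3), [3.0], [2.0, 4.0]).shape == (2, 3, 2)
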
@@ -80,10 +93,8 @@ def test_pmf():
     pdf = Net()
     x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
     output = pdf(x_)
-    print("expected_pmf: ", expect_pmf)
-    print("ans: ", output.asnumpy())
     tol = 1e-6
-    assert (output.asnumpy() - expect_pmf < tol).all()
+    assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()
 
 def test_log_likelihood():
     """
@@ -94,10 +105,8 @@ def test_log_likelihood():
     logprob = Net1()
     x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
     output = logprob(x_)
-    print("expected_log_probability: ", expect_logpmf)
-    print("ans: ", output.asnumpy())
     tol = 1e-6
-    assert (output.asnumpy() - expect_logpmf < tol).all()
+    assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()
 
 def test_kl_loss():
     """
@@ -110,10 +119,8 @@ def test_kl_loss():
     expect_kl_loss = probs1_a * np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b)
     kl_loss = Net2()
     output = kl_loss(Tensor([probs1_b], dtype=dtype.float32))
-    print("expected_kl_loss: ", expect_kl_loss)
-    print("ans: ", output.asnumpy())
     tol = 1e-6
-    assert (output.asnumpy() - expect_kl_loss < tol).all()
+    assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
 
 def test_basics():
     """
@@ -121,8 +128,20 @@ def test_basics():
     basics = Net3()
     mean, sd = basics()
-    print("mean : ", mean)
-    print("sd : ", sd)
+    expect_mean = [0.5, 0.5]
+    assert (mean.asnumpy() == expect_mean).all()
+    assert (sd.asnumpy() == expect_mean).all()  # for probs = 0.5, sd == mean == 0.5
     b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32)
     probs = b.probs()
-    print("probs is ", probs)
+    expect_probs = [0.7, 0.5]
+    tol = 1e-6
+    assert (np.abs(probs.asnumpy() - expect_probs) < tol).all()
+
+def test_sample():
+    """
+    Test sample.
+    """
+    shape = (2, 3)
+    sample = Net4(shape)
+    output = sample()
+    assert output.shape == (2, 3, 2)
diff --git a/tests/st/ops/ascend/test_distribution/test_normal.py b/tests/st/ops/ascend/test_distribution/test_normal.py
index 9977f934ad..52bb1173ee 100644
--- a/tests/st/ops/ascend/test_distribution/test_normal.py
+++ b/tests/st/ops/ascend/test_distribution/test_normal.py
@@ -65,12 +65,25 @@ class Net3(nn.Cell):
     """
     def __init__(self):
         super(Net3, self).__init__()
-        self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
+        self.n = nn.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32)
 
     @ms_function
     def construct(self):
        return self.n('mean'), self.n('sd')
 
+class Net4(nn.Cell):
+    """
+    Test class: sample of normal distribution.
+    """
+    def __init__(self, shape, seed=0):
+        super(Net4, self).__init__()
+        self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32)
+        self.shape = shape
+
+    @ms_function
+    def construct(self, mean=None, sd=None):
+        return self.n('sample', self.shape, mean, sd)
+
 def test_pdf():
     """
     Test pdf.
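
The expected values in the KL tests here come from the usual closed forms, with probs0 = 1 - probs1 in the Bernoulli case. A small NumPy sketch of both quantities, for reference only (the function names are illustrative, not part of this patch); kl_normal below is algebraically identical to the expm1 formulation used in Normal._kl_loss:

    import numpy as np

    def kl_bernoulli(p_a, p_b):
        """KL(Bernoulli(p_a) || Bernoulli(p_b)) in closed form."""
        return (p_a * np.log(p_a / p_b)
                + (1. - p_a) * np.log((1. - p_a) / (1. - p_b)))

    def kl_normal(mean_a, sd_a, mean_b, sd_b):
        """KL(N(mean_a, sd_a^2) || N(mean_b, sd_b^2)) in closed form."""
        return (np.log(sd_b / sd_a)
                + (sd_a ** 2 + (mean_a - mean_b) ** 2) / (2. * sd_b ** 2)
                - 0.5)
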
@@ -79,10 +92,8 @@ def test_pdf():
     expect_pdf = norm_benchmark.pdf([1.0, 2.0]).astype(np.float32)
     pdf = Net()
     output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32))
-    print("expected_pdf: ", expect_pdf)
-    print("ans: ", output.asnumpy())
     tol = 1e-6
-    assert (output.asnumpy() - expect_pdf < tol).all()
+    assert (np.abs(output.asnumpy() - expect_pdf) < tol).all()
 
 def test_log_likelihood():
     """
@@ -92,10 +103,8 @@ def test_log_likelihood():
     expect_logpdf = norm_benchmark.logpdf([1.0, 2.0]).astype(np.float32)
     logprob = Net1()
     output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32))
-    print("expected_log_probability: ", expect_logpdf)
-    print("ans: ", output.asnumpy())
     tol = 1e-6
-    assert (output.asnumpy() - expect_logpdf < tol).all()
+    assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all()
 
 def test_kl_loss():
     """
@@ -115,10 +124,8 @@ def test_kl_loss():
     mean = Tensor(mean_b, dtype=dtype.float32)
     sd = Tensor(sd_b, dtype=dtype.float32)
     output = kl_loss(mean, sd)
-    print("expected_kl_loss: ", expect_kl_loss)
-    print("ans: ", output.asnumpy())
     tol = 1e-6
-    assert (output.asnumpy() - expect_kl_loss < tol).all()
+    assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
 
 def test_basics():
     """
@@ -126,5 +133,20 @@ def test_basics():
     basics = Net3()
     mean, sd = basics()
-    print("mean is ", mean)
-    print("sd is ", sd)
+    expect_mean = [3.0, 3.0]
+    expect_sd = [2.0, 4.0]
+    tol = 1e-6
+    assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
+    assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
+
+def test_sample():
+    """
+    Test sample.
+    """
+    shape = (2, 3)
+    seed = 10
+    mean = Tensor([2.0], dtype=dtype.float32)
+    sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
+    sample = Net4(shape, seed=seed)
+    output = sample(mean, sd)
+    assert output.shape == (2, 3, 3)
diff --git a/tests/ut/python/nn/test_distribution.py b/tests/ut/python/nn/test_distribution.py
index dbb6bf523c..845c64a110 100644
--- a/tests/ut/python/nn/test_distribution.py
+++ b/tests/ut/python/nn/test_distribution.py
@@ -36,18 +36,18 @@ def test_no_arguments():
     No args passed in during initialization.
     """
     n = nn.Normal()
+    assert isinstance(n, nn.Distribution)
     b = nn.Bernoulli()
-    print(n)
-    print(b)
+    assert isinstance(b, nn.Distribution)
 
 def test_with_arguments():
     """
     Args passed in during initialization.
""" n = nn.Normal([3.0], [4.0], dtype=dtype.float32) + assert isinstance(n, nn.Distribution) b = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32) - print(n) - print(b) + assert isinstance(b, nn.Distribution) class NormalProb(nn.Cell): """ @@ -69,8 +69,8 @@ def test_normal_prob(): net = NormalProb() value = Tensor([0.5, 1.0], dtype=dtype.float32) pdf, log_pdf = net(value) - print("pdf: ", pdf) - print("log_pdf: ", log_pdf) + assert isinstance(pdf, Tensor) + assert isinstance(log_pdf, Tensor) class NormalProb1(nn.Cell): """ @@ -94,9 +94,8 @@ def test_normal_prob1(): mean = Tensor([0.0], dtype=dtype.float32) sd = Tensor([1.0], dtype=dtype.float32) pdf, log_pdf = net(value, mean, sd) - print("pdf: ", pdf) - print("log_pdf: ", log_pdf) - + assert isinstance(pdf, Tensor) + assert isinstance(log_pdf, Tensor) class NormalProb2(nn.Cell): """ @@ -121,8 +120,8 @@ def test_normal_prob2(): mean = Tensor([0.0], dtype=dtype.float32) sd = Tensor([1.0], dtype=dtype.float32) pdf, log_pdf = net(value, mean, sd) - print("pdf: ", pdf) - print("log_pdf: ", log_pdf) + assert isinstance(pdf, Tensor) + assert isinstance(log_pdf, Tensor) class BernoulliProb(nn.Cell): """ @@ -133,9 +132,19 @@ class BernoulliProb(nn.Cell): self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32) def construct(self, value): - x = self.bernoulli('prob', value) - y = self.bernoulli('log_prob', value) - return x, y + return self.bernoulli('prob', value) + +class BernoulliLogProb(nn.Cell): + """ + Bernoulli distribution: initialize with probs. + """ + def __init__(self): + super(BernoulliLogProb, self).__init__() + self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32) + + def construct(self, value): + return self.bernoulli('log_prob', value) + def test_bernoulli_prob(): """ @@ -143,10 +152,17 @@ def test_bernoulli_prob(): """ net = BernoulliProb() value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) - ans = net(value) - print("pmf: ", ans) - print("log_pmf: ", ans) + pmf = net(value) + assert isinstance(pmf, Tensor) +def test_bernoulli_log_prob(): + """ + Test pmf/log_pmf: passing value through construct. + """ + net = BernoulliLogProb() + value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) + log_pmf = net(value) + assert isinstance(log_pmf, Tensor) class BernoulliProb1(nn.Cell): """ @@ -157,9 +173,19 @@ class BernoulliProb1(nn.Cell): self.bernoulli = nn.Bernoulli() def construct(self, value, probs): - x = self.bernoulli('prob', value, probs) - y = self.bernoulli('log_prob', value, probs) - return x, y + return self.bernoulli('prob', value, probs) + +class BernoulliLogProb1(nn.Cell): + """ + Bernoulli distribution: initialize without probs. + """ + def __init__(self): + super(BernoulliLogProb1, self).__init__() + self.bernoulli = nn.Bernoulli() + + def construct(self, value, probs): + return self.bernoulli('log_prob', value, probs) + def test_bernoulli_prob1(): """ @@ -168,10 +194,18 @@ def test_bernoulli_prob1(): net = BernoulliProb1() value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) probs = Tensor([0.3], dtype=dtype.float32) - ans = net(value, probs) - print("pmf: ", ans) - print("log_pmf: ", ans) + pmf = net(value, probs) + assert isinstance(pmf, Tensor) +def test_bernoulli_log_prob1(): + """ + Test pmf/log_pmf: passing probs through construct. 
+    """
+    net = BernoulliLogProb1()
+    value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
+    probs = Tensor([0.3], dtype=dtype.float32)
+    log_pmf = net(value, probs)
+    assert isinstance(log_pmf, Tensor)
 
 class BernoulliProb2(nn.Cell):
     """
@@ -182,9 +216,19 @@ class BernoulliProb2(nn.Cell):
         self.bernoulli = nn.Bernoulli(0.5)
 
     def construct(self, value, probs):
-        x = self.bernoulli('prob', value, probs)
-        y = self.bernoulli('log_prob', value, probs)
-        return x, y
+        return self.bernoulli('prob', value, probs)
+
+class BernoulliLogProb2(nn.Cell):
+    """
+    Bernoulli distribution: initialize with probs.
+    """
+    def __init__(self):
+        super(BernoulliLogProb2, self).__init__()
+        self.bernoulli = nn.Bernoulli(0.5)
+
+    def construct(self, value, probs):
+        return self.bernoulli('log_prob', value, probs)
+
 
 def test_bernoulli_prob2():
     """
@@ -194,9 +238,20 @@ def test_bernoulli_prob2():
     net = BernoulliProb2()
     value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
     probs = Tensor([0.3], dtype=dtype.float32)
-    ans = net(value, probs)
-    print("pmf: ", ans)
-    print("log_pmf: ", ans)
+    pmf = net(value, probs)
+    assert isinstance(pmf, Tensor)
+
+def test_bernoulli_log_prob2():
+    """
+    Test log_pmf: passing probs/value through construct.
+    Overwrite original probs.
+    """
+    net = BernoulliLogProb2()
+    value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
+    probs = Tensor([0.3], dtype=dtype.float32)
+    log_pmf = net(value, probs)
+    assert isinstance(log_pmf, Tensor)
+
 
 class NormalKl(nn.Cell):
     """
@@ -229,13 +284,61 @@ def test_kl():
     sd_b = np.array([1.0]).astype(np.float32)
     mean = Tensor(mean_b, dtype=dtype.float32)
     sd = Tensor(sd_b, dtype=dtype.float32)
-    output = nor_net(mean, sd)
-    print("normal-normal kl loss: ", output)
+    loss = nor_net(mean, sd)
+    assert isinstance(loss, Tensor)
 
     ber_net = BernoulliKl()
     probs_b = Tensor([0.3], dtype=dtype.float32)
-    output = ber_net(probs_b)
-    print("bernoulli-bernoulli kl loss: ", output)
+    loss = ber_net(probs_b)
+    assert isinstance(loss, Tensor)
+
+
+class NormalKlNoArgs(nn.Cell):
+    """
+    Test class: kl_loss of Normal distribution.
+    No args during initialization.
+    """
+    def __init__(self):
+        super(NormalKlNoArgs, self).__init__()
+        self.n = nn.Normal(dtype=dtype.float32)
+
+    def construct(self, x_, y_, w_, v_):
+        return self.n('kl_loss', 'Normal', x_, y_, w_, v_)
+
+class BernoulliKlNoArgs(nn.Cell):
+    """
+    Test class: kl_loss between Bernoulli distributions.
+    No args during initialization.
+    """
+    def __init__(self):
+        super(BernoulliKlNoArgs, self).__init__()
+        self.b = nn.Bernoulli(dtype=dtype.int32)
+
+    def construct(self, x_, y_):
+        return self.b('kl_loss', 'Bernoulli', x_, y_)
+
+def test_kl_no_args():
+    """
+    Test kl_loss function: no args during initialization.
+    """
+    nor_net = NormalKlNoArgs()
+    mean_b = np.array([1.0]).astype(np.float32)
+    sd_b = np.array([1.0]).astype(np.float32)
+    mean_a = np.array([2.0]).astype(np.float32)
+    sd_a = np.array([3.0]).astype(np.float32)
+    mean_b = Tensor(mean_b, dtype=dtype.float32)
+    sd_b = Tensor(sd_b, dtype=dtype.float32)
+    mean_a = Tensor(mean_a, dtype=dtype.float32)
+    sd_a = Tensor(sd_a, dtype=dtype.float32)
+    loss = nor_net(mean_b, sd_b, mean_a, sd_a)
+    assert isinstance(loss, Tensor)
+
+    ber_net = BernoulliKlNoArgs()
+    probs_b = Tensor([0.3], dtype=dtype.float32)
+    probs_a = Tensor([0.7], dtype=dtype.float32)
+    loss = ber_net(probs_b, probs_a)
+    assert isinstance(loss, Tensor)
+
 
 
 class NormalBernoulli(nn.Cell):
     """
@@ -244,7 +347,7 @@ class NormalBernoulli(nn.Cell):
     """
     def __init__(self):
         super(NormalBernoulli, self).__init__()
-        self.n = nn.Normal(3.0, 4.0, dtype=dtype.int32)
+        self.n = nn.Normal(3.0, 4.0, dtype=dtype.float32)
         self.b = nn.Bernoulli(0.5, dtype=dtype.int32)
 
     def construct(self):
@@ -260,7 +363,7 @@ def test_bascis():
     """
     net = NormalBernoulli()
     normal_mean, normal_sd, bernoulli_mean, bernoulli_sd = net()
-    print("Mean of Normal distribution: ", normal_mean)
-    print("Standard deviation of Normal distribution: ", normal_sd)
-    print("Mean of Bernoulli distribution: ", bernoulli_mean)
-    print("Standard deviation of Bernoulli distribution: ", bernoulli_sd)
+    assert isinstance(normal_mean, Tensor)
+    assert isinstance(normal_sd, Tensor)
+    assert isinstance(bernoulli_mean, Tensor)
+    assert isinstance(bernoulli_sd, Tensor)
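
Taken together, the tests exercise a single calling convention: every query is routed through construct with the operation name ('prob', 'log_prob', 'kl_loss', 'mean', 'sd', 'sample') as the first positional argument. A minimal end-to-end sketch of that convention, assuming the same imports the unit tests use (the class name NormalDemo is illustrative only, not part of this patch):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, dtype

    class NormalDemo(nn.Cell):
        """Route every Normal query through the string-keyed construct."""
        def __init__(self):
            super(NormalDemo, self).__init__()
            self.n = nn.Normal(3.0, 4.0, dtype=dtype.float32)

        def construct(self, value):
            pdf = self.n('prob', value)
            log_pdf = self.n('log_prob', value)
            return pdf, log_pdf, self.n('mean'), self.n('sd')

    demo = NormalDemo()
    value = Tensor(np.array([2.0, 5.0]), dtype=dtype.float32)
    pdf, log_pdf, mean, sd = demo(value)
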