Merge pull request !2605 from XunDeng/pp_poc_v3tags/v0.6.0-beta
| @@ -17,13 +17,15 @@ Neural Networks Cells. | |||
| Pre-defined building blocks or computing units to construct Neural Networks. | |||
| """ | |||
| from . import layer, loss, optim, metrics, wrap | |||
| from . import layer, loss, optim, metrics, wrap, distribution | |||
| from .cell import Cell, GraphKernel | |||
| from .layer import * | |||
| from .loss import * | |||
| from .optim import * | |||
| from .metrics import * | |||
| from .wrap import * | |||
| from .distribution import * | |||
| __all__ = ["Cell", "GraphKernel"] | |||
| __all__.extend(layer.__all__) | |||
| @@ -31,5 +33,7 @@ __all__.extend(loss.__all__) | |||
| __all__.extend(optim.__all__) | |||
| __all__.extend(metrics.__all__) | |||
| __all__.extend(wrap.__all__) | |||
| __all__.extend(distribution.__all__) | |||
| __all__.sort() | |||
| @@ -0,0 +1,27 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Distribution. | |||
| The high-level components(Distributions) used to construct the probabilistic network. | |||
| """ | |||
| from .distribution import Distribution | |||
| from .normal import Normal | |||
| from .bernoulli import Bernoulli | |||
| __all__ = ['Distribution', | |||
| 'Normal', | |||
| 'Bernoulli',] | |||
| @@ -0,0 +1,24 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Distribution operation utility functions. | |||
| """ | |||
| from .utils import * | |||
| __all__ = ['check_scalar', 'convert_to_batch', 'cast_to_tensor', | |||
| 'calc_batch_size', 'check_greater', | |||
| 'check_greater_equal_zero', | |||
| 'calc_broadcast_shape_from_param', | |||
| 'check_scalar_from_param', 'check_prob'] | |||
| @@ -0,0 +1,199 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Utitly functions to help distribution class.""" | |||
| import numpy as np | |||
| from mindspore.ops import _utils as utils | |||
| from ....common.tensor import Tensor, MetaTensor | |||
| from ....common.parameter import Parameter | |||
| from ....common import dtype as mstype | |||
| def check_scalar(value): | |||
| """ | |||
| Check if input value is a scalar. | |||
| """ | |||
| return np.isscalar(value) | |||
| def cast_to_tensor(t, dtype=mstype.float32): | |||
| """ | |||
| Cast an user input value into a Tensor of dtype. | |||
| Args: | |||
| t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor. | |||
| dtype (mindspore.dtype): dtype of the Tensor. Default: mstype.float32. | |||
| Raises: | |||
| RuntimeError: if t cannot be cast to Tensor. | |||
| Returns: | |||
| Tensor. | |||
| """ | |||
| if isinstance(t, Parameter): | |||
| return t | |||
| if isinstance(t, Tensor): | |||
| #check if the Tensor in shape of Tensor(4) | |||
| if t.dim() == 0: | |||
| value = t.asnumpy() | |||
| return Tensor([t], dtype=dtype) | |||
| #convert the type of tensor to dtype | |||
| t.set_dtype(dtype) | |||
| return t | |||
| if isinstance(t, (list, np.ndarray)): | |||
| return Tensor(t, dtype=dtype) | |||
| if check_scalar(t): | |||
| return Tensor([t], dtype=dtype) | |||
| raise RuntimeError("Input type is not supported.") | |||
| def calc_batch_size(batch_shape): | |||
| """ | |||
| Calculate the size of a given batch_shape. | |||
| Args: | |||
| batch_shape (tuple): batch shape to be calculated. | |||
| Returns: | |||
| int. | |||
| """ | |||
| return int(np.prod(batch_shape)) | |||
| def convert_to_batch(t, batch_shape, dtype): | |||
| """ | |||
| Convert a Tensor to a given batch shape. | |||
| Args: | |||
| t (Tensor, Parameter): Tensor to be converted. | |||
| batch_shape (tuple): desired batch shape. | |||
| dtype (mindspore.dtype): desired dtype. | |||
| Raises: | |||
| RuntimeError: if the converison cannot be done. | |||
| Returns: | |||
| Tensor, with shape of batch_shape. | |||
| """ | |||
| if isinstance(t, Parameter): | |||
| return t | |||
| t = cast_to_tensor(t, dtype) | |||
| if t.shape != batch_shape: | |||
| mul = calc_batch_size(batch_shape) // t.size() | |||
| if (calc_batch_size(batch_shape) % t.size()) != 0: | |||
| raise RuntimeError("Cannot cast the tensor to the given batch shape.") | |||
| temp = list(t.asnumpy()) * mul | |||
| temp = np.reshape(temp, batch_shape) | |||
| return Tensor(temp, dtype) | |||
| return t | |||
| def check_scalar_from_param(params): | |||
| """ | |||
| Check if params are all scalars. | |||
| Args: | |||
| params (dict): parameters used to initialize distribution. | |||
| Notes: String parameters are excluded. | |||
| """ | |||
| for value in params.values(): | |||
| if isinstance(value, (str, type(params['dtype']))): | |||
| continue | |||
| elif check_scalar(value): | |||
| continue | |||
| else: | |||
| return False | |||
| return True | |||
| def calc_broadcast_shape_from_param(params): | |||
| """ | |||
| Calculate the broadcast shape from params. | |||
| Args: | |||
| params (dict): parameters used to initialize distribution. | |||
| Returns: | |||
| tuple. | |||
| """ | |||
| broadcast_shape = [] | |||
| for value in params.values(): | |||
| if isinstance(value, (str, type(params['dtype']))): | |||
| continue | |||
| if value is None: | |||
| return None | |||
| if isinstance(value, Parameter): | |||
| value_t = value.default_input | |||
| else: | |||
| value_t = cast_to_tensor(value, params['dtype']) | |||
| broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name']) | |||
| return tuple(broadcast_shape) | |||
| def check_greater_equal_zero(value, name): | |||
| """ | |||
| Check if the given Tensor is greater zero. | |||
| Args: | |||
| value (Tensor, Parameter): value to be checked. | |||
| name (str) : name of the value. | |||
| Raises: | |||
| ValueError: if the input value is less than zero. | |||
| """ | |||
| if isinstance(value, Parameter): | |||
| if isinstance(value.default_input, MetaTensor): | |||
| return | |||
| value = value.default_input | |||
| comp = np.less(value.asnumpy(), np.zeros(value.shape)) | |||
| if comp.any(): | |||
| raise ValueError(f'{name} should be greater than zero.') | |||
| def check_greater(a, b, name_a, name_b): | |||
| """ | |||
| Check if Tensor b is strictly greater than Tensor a. | |||
| Args: | |||
| a (Tensor): input tensor a. | |||
| b (Tensor): input tensor b. | |||
| name_a (str): name of Tensor_a. | |||
| name_b (str): name of Tensor_b. | |||
| Raises: | |||
| ValueError: if b is less than or equal to a | |||
| """ | |||
| comp = np.less(a.asnumpy(), b.asnumpy()) | |||
| if not comp.all(): | |||
| raise ValueError(f'{name_a} should be less than {name_b}') | |||
| def check_prob(p): | |||
| """ | |||
| Check if p is a proper probability, i.e. 0 <= p <=1. | |||
| Args: | |||
| p (Tensor, Parameter): value to be checked. | |||
| Raises: | |||
| ValueError: if p is not a proper probability. | |||
| """ | |||
| if isinstance(p, Parameter): | |||
| if isinstance(p.default_input, MetaTensor): | |||
| return | |||
| p = p.default_input | |||
| comp = np.less(p.asnumpy(), np.zeros(p.shape)) | |||
| if comp.any(): | |||
| raise ValueError('Probabilities should be greater than or equal to zero') | |||
| comp = np.greater(p.asnumpy(), np.ones(p.shape)) | |||
| if comp.any(): | |||
| raise ValueError('Probabilities should be less than or equal to one') | |||
| @@ -0,0 +1,167 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Bernoulli Distribution""" | |||
| from mindspore.ops import operations as P | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_prob | |||
| from ...common import dtype as mstype | |||
| class Bernoulli(Distribution): | |||
| """ | |||
| Example class: Bernoulli Distribution. | |||
| Args: | |||
| probs (int, float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome. | |||
| seed (int): seed to use in sampling. Default: 0. | |||
| dtype (mindspore.dtype): type of the distribution. Default: mstype.int32. | |||
| name (str): name of the distribution. Default: Bernoulli. | |||
| Note: | |||
| probs should be proper probabilities (0 <= p <= 1). | |||
| Examples: | |||
| >>> # To initialize a Bernoulli distribution which has equal probability of getting 1 and 0 | |||
| >>> b = nn.Bernoulli(0.5, dtype = mstype.int32) | |||
| >>> # The following create two independent Bernoulli distributions | |||
| >>> b = nn.Bernoulli([0.7, 0.2], dtype = mstype.int32) | |||
| """ | |||
| def __init__(self, | |||
| probs=None, | |||
| seed=0, | |||
| dtype=mstype.int32, | |||
| name="Bernoulli"): | |||
| """ | |||
| Constructor of Bernoulli distribution. | |||
| """ | |||
| param = dict(locals()) | |||
| super(Bernoulli, self).__init__(dtype, name, param) | |||
| if probs is not None: | |||
| self._probs = cast_to_tensor(probs) | |||
| check_prob(self._probs) | |||
| else: | |||
| self._probs = probs | |||
| # ops needed for the class | |||
| self.log = P.Log() | |||
| self.add = P.TensorAdd() | |||
| self.mul = P.Mul() | |||
| self.sqrt = P.Sqrt() | |||
| self.realdiv = P.RealDiv() | |||
| self.shape = P.Shape() | |||
| self.const = P.ScalarToArray() | |||
| self.less = P.Less() | |||
| self.cast = P.Cast() | |||
| self.normal = P.Normal(seed=seed) | |||
| self.erf = P.Erf() | |||
| self.sqrt = P.Sqrt() | |||
| def extend_repr(self): | |||
| str_info = f'probs = {self._probs}' | |||
| return str_info | |||
| def probs(self): | |||
| """ | |||
| Returns the probability for the outcome is 1. | |||
| """ | |||
| return self._probs | |||
| def _mean(self, name='mean', probs1=None): | |||
| r""" | |||
| .. math:: | |||
| MEAN(B) = probs1 | |||
| """ | |||
| if name == 'mean': | |||
| return self._probs if probs1 is None else probs1 | |||
| return None | |||
| def _var(self, name='var', probs1=None): | |||
| r""" | |||
| .. math:: | |||
| VAR(B) = probs1 * probs0 | |||
| """ | |||
| if name in ('sd', 'var'): | |||
| probs1 = self._probs if probs1 is None else probs1 | |||
| probs0 = self.add(1, -1 * probs1) | |||
| return self.mul(probs0, probs1) | |||
| return None | |||
| def _prob(self, name, value, probs=None): | |||
| r""" | |||
| pmf of Bernoulli distribution. | |||
| Args: | |||
| name (str): name of the function. Should be "prob" when passed in from construct. | |||
| value (Tensor): a Tensor composed of only zeros and ones. | |||
| probs (Tensor): probability of outcome is 1. Default: self._probs. | |||
| .. math:: | |||
| pmf(k) = probs1 if k = 1; | |||
| pmf(k) = probs0 if k = 0; | |||
| """ | |||
| if name in ('prob', 'log_prob'): | |||
| probs1 = self._probs if probs is None else probs | |||
| probs0 = self.add(1, -1 * probs1) | |||
| return self.add(self.mul(probs1, value), | |||
| self.mul(probs0, self.add(1, -1 * value))) | |||
| return None | |||
| def _kl_loss(self, name, dist, probs1_b, probs1_a=None): | |||
| r""" | |||
| Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). | |||
| Args: | |||
| name (str): name of the funtion. Should always be "kl_loss" when passed in from construct. | |||
| dist (str): type of the distributions. Should be "Bernoulli" in this case. | |||
| probs1_b (Tensor): probs1 of distribution b. | |||
| probs1_a (Tensor): probs1 of distribution a. Default: self._probs. | |||
| .. math:: | |||
| KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) + | |||
| probs0_a * \log(\fract{probs0_a}{probs0_b}) | |||
| """ | |||
| if name == 'kl_loss' and dist == 'Bernoulli': | |||
| probs1_a = self._probs if probs1_a is None else probs1_a | |||
| probs0_a = self.add(1, -1 * probs1_a) | |||
| probs0_b = self.add(1, -1 * probs1_b) | |||
| return self.add(probs1_a * self.log(self.realdiv(probs1_a, probs1_b)), | |||
| probs0_a * self.log(self.realdiv(probs0_a, probs0_b))) | |||
| return None | |||
| def _sample(self, name, shape=(), probs=None): | |||
| """ | |||
| Sampling. | |||
| Args: | |||
| name (str): name of the function. Should always be 'sample' when passed in from construct. | |||
| shape (tuple): shape of the sample. Default: (). | |||
| probs (Tensor): probs1 of the samples. Default: self._probs. | |||
| Returns: | |||
| Tensor, shape is shape + batch_shape. | |||
| """ | |||
| if name == 'sample': | |||
| probs1 = self._probs if probs is None else probs | |||
| batch_shape = self.shape(probs1) | |||
| sample_shape = shape + batch_shape | |||
| mean_zero = self.const(0.0) | |||
| sd_one = self.const(1.0) | |||
| sqrt_two = self.sqrt(self.const(2.0)) | |||
| sample_norm = self.normal(sample_shape, mean_zero, sd_one) | |||
| sample_uniform = 0.5 * (1 + self.erf(self.realdiv(sample_norm, sqrt_two))) | |||
| sample = self.less(sample_uniform, probs1) | |||
| sample = self.cast(sample, self._dtype) | |||
| return sample | |||
| return None | |||
| @@ -0,0 +1,200 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """basic""" | |||
| from ..cell import Cell | |||
| from ._utils.utils import calc_broadcast_shape_from_param | |||
| class Distribution(Cell): | |||
| """ | |||
| Base class for all mathematical distributions. | |||
| Args: | |||
| dtype (mindspore.dtype): type of the distribution. | |||
| name (str): name of the distribution. | |||
| param (dict): parameters used to initialize the distribution. | |||
| Note: | |||
| Derived class should override operations such as ,_mean, _prob, | |||
| and _log_prob. Functions should be called through construct when | |||
| used inside a network in the form of function name followed by | |||
| arguments. | |||
| Examples: | |||
| >>> class MyNormalDistribution(Distribution): | |||
| >>> def __init__(self): | |||
| >>> super(MyDistribution, self).__init__() | |||
| >>> self._mean_value = Tensor([2.0,3.0]) | |||
| >>> self._sd_value = Tensor([2.0,3.0]) | |||
| >>> | |||
| >>> def _mean(self): | |||
| >>> return self._mean_value | |||
| """ | |||
| def __init__(self, | |||
| dtype, | |||
| name, | |||
| param): | |||
| """ | |||
| Constructor of distribution class. | |||
| """ | |||
| super(Distribution, self).__init__() | |||
| self._name = name | |||
| self._dtype = dtype | |||
| self._parameters = {} | |||
| # parsing parameters | |||
| for k in param.keys(): | |||
| if not(k == 'self' or k.startswith('_')): | |||
| self._parameters[k] = param[k] | |||
| # some attributes | |||
| self._broadcast_shape = calc_broadcast_shape_from_param( | |||
| self._parameters) | |||
| # set the function to call according to the derived class's attributes | |||
| self._set_prob() | |||
| self._set_log_prob() | |||
| self._set_sd() | |||
| def _set_prob(self): | |||
| """ | |||
| Set probability funtion based on the availability of _prob and _log_likehood. | |||
| """ | |||
| if hasattr(self, '_prob'): | |||
| self._call_prob = self._prob | |||
| elif hasattr(self, '_log_likelihood'): | |||
| self._call_prob = self._calc_prob_from_log_likelihood | |||
| def _set_sd(self): | |||
| """ | |||
| Set standard deviation based on the availability of _sd and _var. | |||
| """ | |||
| if hasattr(self, '_sd'): | |||
| self._call_sd = self._sd | |||
| elif hasattr(self, '_var'): | |||
| self._call_sd = self._calc_sd_from_var | |||
| def _set_log_prob(self): | |||
| """ | |||
| Set log probability based on the availability of _prob and _log_likelihood. | |||
| """ | |||
| if hasattr(self, '_log_likelihood'): | |||
| self._call_log_prob = self._log_likelihood | |||
| if hasattr(self, '_prob'): | |||
| self._call_log_prob = self._calc_log_prob_from_prob | |||
| def log_likelihood(self, *args): | |||
| """ | |||
| Evaluate the log probability at the given value. | |||
| Note: | |||
| value is casted to Tensor for further calculation. | |||
| Returns: | |||
| Tensor, shape is the broadcast_shape of the distribution. | |||
| """ | |||
| return self._call_log_prob(*args) | |||
| def _calc_prob_from_log_likelihood(self, *args): | |||
| r""" | |||
| Evaluate prob from log probability. | |||
| .. math:: | |||
| probability(x) = \exp(log_likehood(x)) | |||
| """ | |||
| return self.exp(self._log_likelihood(*args)) | |||
| def prob(self, *args): | |||
| """ | |||
| Evaluate the prob (pdf or pmf) at given value. | |||
| Note: | |||
| value is casted to Tensor for further calculation. | |||
| Returns: | |||
| Tensor, shape is the broadcast_shape of the distribution. | |||
| """ | |||
| return self._call_prob(*args) | |||
| def _calc_log_prob_from_prob(self, *args): | |||
| r""" | |||
| Evaluate log probability from probability. | |||
| .. math:: | |||
| log_prob(x) = \log(prob(x)) | |||
| """ | |||
| return self.log(self._prob(*args)) | |||
| def kl_loss(self, **kwargs): | |||
| """ | |||
| Evaluate the KL divergence. Parameters of the second distribution should be | |||
| passed in through **kwargs. | |||
| Returns: | |||
| Tensor, shape is the broadcast_shape of the distribution and input distribution. | |||
| """ | |||
| return self._kl_loss(**kwargs) | |||
| def mean(self, **kwargs): | |||
| """ | |||
| Evaluate the mean. | |||
| Returns: | |||
| Tensor, shape is the broadcast_shape of the distribution. | |||
| """ | |||
| return self._mean(**kwargs) | |||
| def sd(self, **kwargs): | |||
| """ | |||
| Evaluate the standard deviation. | |||
| Returns: | |||
| Tensor, shape is the broadcast_shape of the distribution. | |||
| """ | |||
| return self._call_sd(**kwargs) | |||
| def _calc_sd_from_var(self, *args): | |||
| r""" | |||
| Evaluate log probability from probability. | |||
| .. math:: | |||
| STD(x) = \sqrt(VAR(x)) | |||
| """ | |||
| return self.sqrt(self._var(*args)) | |||
| def construct(self, *inputs): | |||
| """ | |||
| Override construct in Cell. | |||
| Args: | |||
| *inputs: inputs[0] is always the name of the function. | |||
| Notes: | |||
| Always raise RuntimeError as Distribution should not be called directly. | |||
| """ | |||
| if inputs[0] == 'log_prob': | |||
| return self._call_log_prob(*inputs) | |||
| if inputs[0] == 'prob': | |||
| return self._call_prob(*inputs) | |||
| if inputs[0] == 'kl_loss': | |||
| return self._kl_loss(*inputs) | |||
| if inputs[0] == 'mean': | |||
| return self._mean(*inputs) | |||
| if inputs[0] == 'sd': | |||
| return self._call_sd(*inputs) | |||
| if inputs[0] == 'sample': | |||
| return self._sample(*inputs) | |||
| return None | |||
| @@ -0,0 +1,169 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Normal Distribution""" | |||
| import numpy as np | |||
| from mindspore.ops import operations as P | |||
| from .distribution import Distribution | |||
| from ._utils.utils import convert_to_batch, check_greater_equal_zero | |||
| from ...common import dtype as mstype | |||
| from ...context import get_context | |||
| class Normal(Distribution): | |||
| """ | |||
| Example class: Normal distribution. | |||
| Args: | |||
| mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Gaussian distribution. | |||
| sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Gaussian distribution. | |||
| seed (int): seed to use in sampling. Default: 0. | |||
| dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. | |||
| name (str): name of the distribution. Default: Normal. | |||
| Note: | |||
| Standard deviation should be greater than zero. | |||
| Examples: | |||
| >>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0 | |||
| >>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32) | |||
| >>> # The following create two independent normal distributions | |||
| >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) | |||
| """ | |||
| def __init__(self, | |||
| mean=None, | |||
| sd=None, | |||
| seed=0, | |||
| dtype=mstype.float32, | |||
| name="Normal"): | |||
| """ | |||
| Constructor of normal distribution. | |||
| """ | |||
| param = dict(locals()) | |||
| super(Normal, self).__init__(dtype, name, param) | |||
| if mean is not None and sd is not None: | |||
| self._mean_value = convert_to_batch(mean, self._broadcast_shape, dtype) | |||
| self._sd_value = convert_to_batch(sd, self._broadcast_shape, dtype) | |||
| check_greater_equal_zero(self._sd_value, "Standard deviation") | |||
| else: | |||
| self._mean_value = mean | |||
| self._sd_value = sd | |||
| #ops needed for the class | |||
| self.exp = P.Exp() | |||
| self.add = P.TensorAdd() | |||
| self.mul = P.Mul() | |||
| self.sq = P.Square() | |||
| self.log = P.Log() | |||
| self.sqrt = P.Sqrt() | |||
| self.realdiv = P.RealDiv() | |||
| self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step | |||
| self.normal = P.Normal(seed=seed) | |||
| self.shape = P.Shape() | |||
| self.zeroslike = P.ZerosLike() | |||
| self.const = P.ScalarToArray() | |||
| def extend_repr(self): | |||
| str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' | |||
| return str_info | |||
| def _expm1_by_step(self, x): | |||
| """ | |||
| Expm1 ops under GPU context. | |||
| """ | |||
| return self.add(self.exp(x), -1) | |||
| def _mean(self, name='mean', mean=None, sd=None): | |||
| """ | |||
| Mean of the distribution. | |||
| """ | |||
| if name == 'mean': | |||
| mean = self._mean_value if mean is None or sd is None else mean | |||
| return mean | |||
| return None | |||
| def _sd(self, name='sd', mean=None, sd=None): | |||
| """ | |||
| Standard deviation of the distribution. | |||
| """ | |||
| if name in ('sd', 'var'): | |||
| sd = self._sd_value if mean is None or sd is None else sd | |||
| return sd | |||
| return None | |||
| def _log_likelihood(self, name, value, mean=None, sd=None): | |||
| r""" | |||
| Evaluate log probability. | |||
| .. math:: | |||
| L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) | |||
| """ | |||
| if name in ('prob', 'log_prob'): | |||
| mean = self._mean_value if mean is None else mean | |||
| sd = self._sd_value if sd is None else sd | |||
| unnormalized_log_prob = -1. * self.realdiv(self.sq(self.add(value, -1. * mean)), | |||
| 2. * self.sq(sd)) | |||
| neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) | |||
| return self.add(unnormalized_log_prob, neg_normalization) | |||
| return None | |||
| def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None): | |||
| r""" | |||
| Evaluate Normal-Normal kl divergence, i.e. KL(a||b). | |||
| Args: | |||
| name (str): name of the funtion passed in from construct. Should always be "kl_loss". | |||
| dist (str): type of the distributions. Should be "Normal" in this case. | |||
| mean_b (Tensor): mean of distribution b. | |||
| sd_b (Tensor): standard deviation distribution b. | |||
| mean_a (Tensor): mean of distribution a. Default: self._mean_value. | |||
| sd_a (Tensor): standard deviation distribution a. Default: self._sd_value. | |||
| .. math:: | |||
| KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 + | |||
| 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) | |||
| """ | |||
| if name == 'kl_loss' and dist == 'Normal': | |||
| mean_a = self._mean_value if mean_a is None else mean_a | |||
| sd_a = self._sd_value if sd_a is None else sd_a | |||
| diff_log_scale = self.add(self.log(sd_a), - self.log(sd_b)) | |||
| squared_diff = self.sq(self.add(self.realdiv(mean_a, sd_b), - self.realdiv(mean_b, sd_b))) | |||
| return self.add(self.add(0.5 * squared_diff, 0.5 * self.expm1(2 * diff_log_scale)), - diff_log_scale) | |||
| return None | |||
| def _sample(self, name, shape=(), mean=None, sd=None): | |||
| """ | |||
| Sampling. | |||
| Args: | |||
| name (str): name of the function. Should always be 'sample' when passed in from construct. | |||
| shape (tuple): shape of the sample. Default: (). | |||
| mean (Tensor): mean of the samples. Default: self._mean_value. | |||
| sd (Tensor): standard deviation of the samples. Default: self._sd_value. | |||
| Returns: | |||
| Tensor, shape is shape + batch_shape. | |||
| """ | |||
| if name == 'sample': | |||
| mean = self._mean_value if mean is None else mean | |||
| sd = self._sd_value if sd is None else sd | |||
| batch_shape = self.shape(self.add(self.zeroslike(mean), self.zeroslike(sd))) | |||
| sample_shape = shape + batch_shape | |||
| mean_zero = self.const(0.0) | |||
| sd_one = self.const(1.0) | |||
| sample_norm = self.normal(sample_shape, mean_zero, sd_one) | |||
| sample = self.add(mean, self.mul(sample_norm, sd)) | |||
| return sample | |||
| return None | |||
| @@ -0,0 +1,147 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test cases for bernoulli distribution""" | |||
| import numpy as np | |||
| from scipy import stats | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import ms_function | |||
| from mindspore import dtype | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class Net(nn.Cell): | |||
| """ | |||
| Test class: probability of bernoulli distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | |||
| @ms_function | |||
| def construct(self, x_): | |||
| return self.b('prob', x_) | |||
| class Net1(nn.Cell): | |||
| """ | |||
| Test class: log probability of bernoulli distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(Net1, self).__init__() | |||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | |||
| @ms_function | |||
| def construct(self, x_): | |||
| return self.b('log_prob', x_) | |||
| class Net2(nn.Cell): | |||
| """ | |||
| Test class: kl_loss between bernoulli distributions. | |||
| """ | |||
| def __init__(self): | |||
| super(Net2, self).__init__() | |||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | |||
| @ms_function | |||
| def construct(self, x_): | |||
| return self.b('kl_loss', 'Bernoulli', x_) | |||
| class Net3(nn.Cell): | |||
| """ | |||
| Test class: mean/sd of bernoulli distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(Net3, self).__init__() | |||
| self.b = nn.Bernoulli([0.5, 0.5], dtype=dtype.int32) | |||
| @ms_function | |||
| def construct(self): | |||
| return self.b('mean'), self.b('sd') | |||
| class Net4(nn.Cell): | |||
| """ | |||
| Test class: log probability of bernoulli distribution. | |||
| """ | |||
| def __init__(self, shape, seed=0): | |||
| super(Net4, self).__init__() | |||
| self.b = nn.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32) | |||
| self.shape = shape | |||
| @ms_function | |||
| def construct(self, probs=None): | |||
| return self.b('sample', self.shape, probs) | |||
| def test_pmf(): | |||
| """ | |||
| Test pmf. | |||
| """ | |||
| bernoulli_benchmark = stats.bernoulli(0.7) | |||
| expect_pmf = bernoulli_benchmark.pmf([0, 1, 0, 1, 1]).astype(np.float32) | |||
| pdf = Net() | |||
| x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) | |||
| output = pdf(x_) | |||
| tol = 1e-6 | |||
| assert (np.abs(output.asnumpy() - expect_pmf) < tol).all() | |||
| def test_log_likelihood(): | |||
| """ | |||
| Test log_pmf. | |||
| """ | |||
| bernoulli_benchmark = stats.bernoulli(0.7) | |||
| expect_logpmf = bernoulli_benchmark.logpmf([0, 1, 0, 1, 1]).astype(np.float32) | |||
| logprob = Net1() | |||
| x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) | |||
| output = logprob(x_) | |||
| tol = 1e-6 | |||
| assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all() | |||
| def test_kl_loss(): | |||
| """ | |||
| Test kl_loss. | |||
| """ | |||
| probs1_a = 0.7 | |||
| probs1_b = 0.5 | |||
| probs0_a = 1 - probs1_a | |||
| probs0_b = 1 - probs1_b | |||
| expect_kl_loss = probs1_a * np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b) | |||
| kl_loss = Net2() | |||
| output = kl_loss(Tensor([probs1_b], dtype=dtype.float32)) | |||
| tol = 1e-6 | |||
| assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() | |||
| def test_basics(): | |||
| """ | |||
| Test mean/standard deviation and probs. | |||
| """ | |||
| basics = Net3() | |||
| mean, sd = basics() | |||
| expect_mean = [0.5, 0.5] | |||
| assert (mean.asnumpy() == expect_mean).all() | |||
| assert (sd.asnumpy() == expect_mean).all() | |||
| b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32) | |||
| probs = b.probs() | |||
| expect_probs = [0.7, 0.5] | |||
| tol = 1e-6 | |||
| assert (np.abs(probs.asnumpy() - expect_probs) < tol).all() | |||
| def test_sample(): | |||
| """ | |||
| Test sample. | |||
| """ | |||
| shape = (2, 3) | |||
| sample = Net4(shape) | |||
| output = sample() | |||
| assert output.shape == (2, 3, 2) | |||
| @@ -0,0 +1,152 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test cases for normal distribution""" | |||
| import numpy as np | |||
| from scipy import stats | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import ms_function | |||
| from mindspore import dtype | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class Net(nn.Cell): | |||
| """ | |||
| Test class: probability of normal distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) | |||
| @ms_function | |||
| def construct(self, x_): | |||
| return self.n('prob', x_) | |||
| class Net1(nn.Cell): | |||
| """ | |||
| Test class: log probability of normal distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(Net1, self).__init__() | |||
| self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) | |||
| @ms_function | |||
| def construct(self, x_): | |||
| return self.n('log_prob', x_) | |||
| class Net2(nn.Cell): | |||
| """ | |||
| Test class: kl_loss of normal distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(Net2, self).__init__() | |||
| self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) | |||
| @ms_function | |||
| def construct(self, x_, y_): | |||
| return self.n('kl_loss', 'Normal', x_, y_) | |||
| class Net3(nn.Cell): | |||
| """ | |||
| Test class: mean/sd of normal distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(Net3, self).__init__() | |||
| self.n = nn.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32) | |||
| @ms_function | |||
| def construct(self): | |||
| return self.n('mean'), self.n('sd') | |||
| class Net4(nn.Cell): | |||
| """ | |||
| Test class: mean/sd of normal distribution. | |||
| """ | |||
| def __init__(self, shape, seed=0): | |||
| super(Net4, self).__init__() | |||
| self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32) | |||
| self.shape = shape | |||
| @ms_function | |||
| def construct(self, mean=None, sd=None): | |||
| return self.n('sample', self.shape, mean, sd) | |||
| def test_pdf(): | |||
| """ | |||
| Test pdf. | |||
| """ | |||
| norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) | |||
| expect_pdf = norm_benchmark.pdf([1.0, 2.0]).astype(np.float32) | |||
| pdf = Net() | |||
| output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32)) | |||
| tol = 1e-6 | |||
| assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() | |||
| def test_log_likelihood(): | |||
| """ | |||
| Test log_pdf. | |||
| """ | |||
| norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) | |||
| expect_logpdf = norm_benchmark.logpdf([1.0, 2.0]).astype(np.float32) | |||
| logprob = Net1() | |||
| output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32)) | |||
| tol = 1e-6 | |||
| assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() | |||
| def test_kl_loss(): | |||
| """ | |||
| Test kl_loss. | |||
| """ | |||
| mean_a = np.array([3.0]).astype(np.float32) | |||
| sd_a = np.array([4.0]).astype(np.float32) | |||
| mean_b = np.array([1.0]).astype(np.float32) | |||
| sd_b = np.array([1.0]).astype(np.float32) | |||
| diff_log_scale = np.log(sd_a) - np.log(sd_b) | |||
| squared_diff = np.square(mean_a / sd_b - mean_b / sd_b) | |||
| expect_kl_loss = 0.5 * squared_diff + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale | |||
| kl_loss = Net2() | |||
| mean = Tensor(mean_b, dtype=dtype.float32) | |||
| sd = Tensor(sd_b, dtype=dtype.float32) | |||
| output = kl_loss(mean, sd) | |||
| tol = 1e-6 | |||
| assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() | |||
| def test_basics(): | |||
| """ | |||
| Test mean/standard deviation. | |||
| """ | |||
| basics = Net3() | |||
| mean, sd = basics() | |||
| expect_mean = [3.0, 3.0] | |||
| expect_sd = [2.0, 4.0] | |||
| tol = 1e-6 | |||
| assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() | |||
| assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() | |||
| def test_sample(): | |||
| """ | |||
| Test sample. | |||
| """ | |||
| shape = (2, 3) | |||
| seed = 10 | |||
| mean = Tensor([2.0], dtype=dtype.float32) | |||
| sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32) | |||
| sample = Net4(shape, seed=seed) | |||
| output = sample(mean, sd) | |||
| assert output.shape == (2, 3, 3) | |||
| @@ -0,0 +1,369 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Test nn.Distribution. | |||
| Including Normal Distribution and Bernoulli Distribution. | |||
| """ | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import dtype | |||
| from mindspore import Tensor | |||
| def test_normal_shape_errpr(): | |||
| """ | |||
| Invalid shapes. | |||
| """ | |||
| with pytest.raises(ValueError): | |||
| nn.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) | |||
| def test_no_arguments(): | |||
| """ | |||
| No args passed in during initialization. | |||
| """ | |||
| n = nn.Normal() | |||
| assert isinstance(n, nn.Distribution) | |||
| b = nn.Bernoulli() | |||
| assert isinstance(b, nn.Distribution) | |||
| def test_with_arguments(): | |||
| """ | |||
| Args passed in during initialization. | |||
| """ | |||
| n = nn.Normal([3.0], [4.0], dtype=dtype.float32) | |||
| assert isinstance(n, nn.Distribution) | |||
| b = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32) | |||
| assert isinstance(b, nn.Distribution) | |||
| class NormalProb(nn.Cell): | |||
| """ | |||
| Normal distribution: initialize with mean/sd. | |||
| """ | |||
| def __init__(self): | |||
| super(NormalProb, self).__init__() | |||
| self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) | |||
| def construct(self, value): | |||
| x = self.normal('prob', value) | |||
| y = self.normal('log_prob', value) | |||
| return x, y | |||
| def test_normal_prob(): | |||
| """ | |||
| Test pdf/log_pdf: passing value through construct. | |||
| """ | |||
| net = NormalProb() | |||
| value = Tensor([0.5, 1.0], dtype=dtype.float32) | |||
| pdf, log_pdf = net(value) | |||
| assert isinstance(pdf, Tensor) | |||
| assert isinstance(log_pdf, Tensor) | |||
| class NormalProb1(nn.Cell): | |||
| """ | |||
| Normal distribution: initialize without mean/sd. | |||
| """ | |||
| def __init__(self): | |||
| super(NormalProb1, self).__init__() | |||
| self.normal = nn.Normal() | |||
| def construct(self, value, mean, sd): | |||
| x = self.normal('prob', value, mean, sd) | |||
| y = self.normal('log_prob', value, mean, sd) | |||
| return x, y | |||
| def test_normal_prob1(): | |||
| """ | |||
| Test pdf/logpdf: passing mean/sd, value through construct. | |||
| """ | |||
| net = NormalProb1() | |||
| value = Tensor([0.5, 1.0], dtype=dtype.float32) | |||
| mean = Tensor([0.0], dtype=dtype.float32) | |||
| sd = Tensor([1.0], dtype=dtype.float32) | |||
| pdf, log_pdf = net(value, mean, sd) | |||
| assert isinstance(pdf, Tensor) | |||
| assert isinstance(log_pdf, Tensor) | |||
| class NormalProb2(nn.Cell): | |||
| """ | |||
| Normal distribution: initialize with mean/sd. | |||
| """ | |||
| def __init__(self): | |||
| super(NormalProb2, self).__init__() | |||
| self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) | |||
| def construct(self, value, mean, sd): | |||
| x = self.normal('prob', value, mean, sd) | |||
| y = self.normal('log_prob', value, mean, sd) | |||
| return x, y | |||
| def test_normal_prob2(): | |||
| """ | |||
| Test pdf/log_pdf: passing mean/sd through construct. | |||
| Overwrite original mean/sd. | |||
| """ | |||
| net = NormalProb2() | |||
| value = Tensor([0.5, 1.0], dtype=dtype.float32) | |||
| mean = Tensor([0.0], dtype=dtype.float32) | |||
| sd = Tensor([1.0], dtype=dtype.float32) | |||
| pdf, log_pdf = net(value, mean, sd) | |||
| assert isinstance(pdf, Tensor) | |||
| assert isinstance(log_pdf, Tensor) | |||
| class BernoulliProb(nn.Cell): | |||
| """ | |||
| Bernoulli distribution: initialize with probs. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliProb, self).__init__() | |||
| self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32) | |||
| def construct(self, value): | |||
| return self.bernoulli('prob', value) | |||
| class BernoulliLogProb(nn.Cell): | |||
| """ | |||
| Bernoulli distribution: initialize with probs. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliLogProb, self).__init__() | |||
| self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32) | |||
| def construct(self, value): | |||
| return self.bernoulli('log_prob', value) | |||
| def test_bernoulli_prob(): | |||
| """ | |||
| Test pmf/log_pmf: passing value through construct. | |||
| """ | |||
| net = BernoulliProb() | |||
| value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) | |||
| pmf = net(value) | |||
| assert isinstance(pmf, Tensor) | |||
| def test_bernoulli_log_prob(): | |||
| """ | |||
| Test pmf/log_pmf: passing value through construct. | |||
| """ | |||
| net = BernoulliLogProb() | |||
| value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) | |||
| log_pmf = net(value) | |||
| assert isinstance(log_pmf, Tensor) | |||
| class BernoulliProb1(nn.Cell): | |||
| """ | |||
| Bernoulli distribution: initialize without probs. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliProb1, self).__init__() | |||
| self.bernoulli = nn.Bernoulli() | |||
| def construct(self, value, probs): | |||
| return self.bernoulli('prob', value, probs) | |||
| class BernoulliLogProb1(nn.Cell): | |||
| """ | |||
| Bernoulli distribution: initialize without probs. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliLogProb1, self).__init__() | |||
| self.bernoulli = nn.Bernoulli() | |||
| def construct(self, value, probs): | |||
| return self.bernoulli('log_prob', value, probs) | |||
| def test_bernoulli_prob1(): | |||
| """ | |||
| Test pmf/log_pmf: passing probs through construct. | |||
| """ | |||
| net = BernoulliProb1() | |||
| value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) | |||
| probs = Tensor([0.3], dtype=dtype.float32) | |||
| pmf = net(value, probs) | |||
| assert isinstance(pmf, Tensor) | |||
| def test_bernoulli_log_prob1(): | |||
| """ | |||
| Test pmf/log_pmf: passing probs through construct. | |||
| """ | |||
| net = BernoulliLogProb1() | |||
| value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) | |||
| probs = Tensor([0.3], dtype=dtype.float32) | |||
| log_pmf = net(value, probs) | |||
| assert isinstance(log_pmf, Tensor) | |||
| class BernoulliProb2(nn.Cell): | |||
| """ | |||
| Bernoulli distribution: initialize with probs. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliProb2, self).__init__() | |||
| self.bernoulli = nn.Bernoulli(0.5) | |||
| def construct(self, value, probs): | |||
| return self.bernoulli('prob', value, probs) | |||
| class BernoulliLogProb2(nn.Cell): | |||
| """ | |||
| Bernoulli distribution: initialize with probs. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliLogProb2, self).__init__() | |||
| self.bernoulli = nn.Bernoulli(0.5) | |||
| def construct(self, value, probs): | |||
| return self.bernoulli('log_prob', value, probs) | |||
| def test_bernoulli_prob2(): | |||
| """ | |||
| Test pmf/log_pmf: passing probs/value through construct. | |||
| Overwrite original probs. | |||
| """ | |||
| net = BernoulliProb2() | |||
| value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) | |||
| probs = Tensor([0.3], dtype=dtype.float32) | |||
| pmf = net(value, probs) | |||
| assert isinstance(pmf, Tensor) | |||
| def test_bernoulli_log_prob2(): | |||
| """ | |||
| Test pmf/log_pmf: passing probs/value through construct. | |||
| Overwrite original probs. | |||
| """ | |||
| net = BernoulliLogProb2() | |||
| value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) | |||
| probs = Tensor([0.3], dtype=dtype.float32) | |||
| log_pmf = net(value, probs) | |||
| assert isinstance(log_pmf, Tensor) | |||
| class NormalKl(nn.Cell): | |||
| """ | |||
| Test class: kl_loss of Normal distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(NormalKl, self).__init__() | |||
| self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) | |||
| def construct(self, x_, y_): | |||
| return self.n('kl_loss', 'Normal', x_, y_) | |||
| class BernoulliKl(nn.Cell): | |||
| """ | |||
| Test class: kl_loss between Bernoulli distributions. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliKl, self).__init__() | |||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | |||
| def construct(self, x_): | |||
| return self.b('kl_loss', 'Bernoulli', x_) | |||
| def test_kl(): | |||
| """ | |||
| Test kl_loss function. | |||
| """ | |||
| nor_net = NormalKl() | |||
| mean_b = np.array([1.0]).astype(np.float32) | |||
| sd_b = np.array([1.0]).astype(np.float32) | |||
| mean = Tensor(mean_b, dtype=dtype.float32) | |||
| sd = Tensor(sd_b, dtype=dtype.float32) | |||
| loss = nor_net(mean, sd) | |||
| assert isinstance(loss, Tensor) | |||
| ber_net = BernoulliKl() | |||
| probs_b = Tensor([0.3], dtype=dtype.float32) | |||
| loss = ber_net(probs_b) | |||
| assert isinstance(loss, Tensor) | |||
| class NormalKlNoArgs(nn.Cell): | |||
| """ | |||
| Test class: kl_loss of Normal distribution. | |||
| No args during initialization. | |||
| """ | |||
| def __init__(self): | |||
| super(NormalKlNoArgs, self).__init__() | |||
| self.n = nn.Normal(dtype=dtype.float32) | |||
| def construct(self, x_, y_, w_, v_): | |||
| return self.n('kl_loss', 'Normal', x_, y_, w_, v_) | |||
| class BernoulliKlNoArgs(nn.Cell): | |||
| """ | |||
| Test class: kl_loss between Bernoulli distributions. | |||
| No args during initialization. | |||
| """ | |||
| def __init__(self): | |||
| super(BernoulliKlNoArgs, self).__init__() | |||
| self.b = nn.Bernoulli(dtype=dtype.int32) | |||
| def construct(self, x_, y_): | |||
| return self.b('kl_loss', 'Bernoulli', x_, y_) | |||
| def test_kl_no_args(): | |||
| """ | |||
| Test kl_loss function. | |||
| """ | |||
| nor_net = NormalKlNoArgs() | |||
| mean_b = np.array([1.0]).astype(np.float32) | |||
| sd_b = np.array([1.0]).astype(np.float32) | |||
| mean_a = np.array([2.0]).astype(np.float32) | |||
| sd_a = np.array([3.0]).astype(np.float32) | |||
| mean_b = Tensor(mean_b, dtype=dtype.float32) | |||
| sd_b = Tensor(sd_b, dtype=dtype.float32) | |||
| mean_a = Tensor(mean_a, dtype=dtype.float32) | |||
| sd_a = Tensor(sd_a, dtype=dtype.float32) | |||
| loss = nor_net(mean_b, sd_b, mean_a, sd_a) | |||
| assert isinstance(loss, Tensor) | |||
| ber_net = BernoulliKlNoArgs() | |||
| probs_b = Tensor([0.3], dtype=dtype.float32) | |||
| probs_a = Tensor([0.7], dtype=dtype.float32) | |||
| loss = ber_net(probs_b, probs_a) | |||
| assert isinstance(loss, Tensor) | |||
| class NormalBernoulli(nn.Cell): | |||
| """ | |||
| Test class: basic mean/sd function. | |||
| """ | |||
| def __init__(self): | |||
| super(NormalBernoulli, self).__init__() | |||
| self.n = nn.Normal(3.0, 4.0, dtype=dtype.float32) | |||
| self.b = nn.Bernoulli(0.5, dtype=dtype.int32) | |||
| def construct(self): | |||
| normal_mean = self.n('mean') | |||
| normal_sd = self.n('sd') | |||
| bernoulli_mean = self.b('mean') | |||
| bernoulli_sd = self.b('sd') | |||
| return normal_mean, normal_sd, bernoulli_mean, bernoulli_sd | |||
| def test_bascis(): | |||
| """ | |||
| Test mean/sd functionality of Normal and Bernoulli. | |||
| """ | |||
| net = NormalBernoulli() | |||
| normal_mean, normal_sd, bernoulli_mean, bernoulli_sd = net() | |||
| assert isinstance(normal_mean, Tensor) | |||
| assert isinstance(normal_sd, Tensor) | |||
| assert isinstance(bernoulli_mean, Tensor) | |||
| assert isinstance(bernoulli_sd, Tensor) | |||