@@ -15,9 +15,9 @@
 # ============================================================================
 """Utility functions to help distribution class."""
 import numpy as np
-from mindspore.ops import operations as P
 from mindspore.ops import _utils as utils
-from ....common.tensor import Tensor
+from ....common.tensor import Tensor, MetaTensor
+from ....common.parameter import Parameter
 from ....common import dtype as mstype
@@ -33,15 +33,17 @@ def cast_to_tensor(t, dtype=mstype.float32):
     Cast a user input value into a Tensor of dtype.

     Args:
-        t (int/float/list/numpy.ndarray/Tensor).
-        dtype (mindspore.dtype).
+        t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor.
+        dtype (mindspore.dtype): dtype of the Tensor. Default: mstype.float32.

     Raises:
         RuntimeError: if t cannot be cast to Tensor.

-    Outputs:
+    Returns:
         Tensor.
     """
+    if isinstance(t, Parameter):
+        return t
     if isinstance(t, Tensor):
         # check if the Tensor is in the shape of Tensor(4)
         if t.dim() == 0:
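The new `Parameter` early-return above means learnable parameters flow through untouched instead of being frozen into constant Tensors. For orientation, a minimal sketch of the casting rules; this is a re-implementation for illustration only, since the real helper lives in an internal module:

```python
import numpy as np
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype

def cast_to_tensor_sketch(t, dtype=mstype.float32):
    """Illustrative re-implementation of cast_to_tensor's dispatch."""
    if isinstance(t, Parameter):           # learnable parameter: pass through
        return t
    if isinstance(t, Tensor):              # already a Tensor (0-d handling elided)
        return t
    if isinstance(t, (list, np.ndarray)):  # array-likes become Tensors of dtype
        return Tensor(t, dtype)
    if isinstance(t, (int, float)):        # scalars become 1-element Tensors
        return Tensor(np.array([t]), dtype)
    raise RuntimeError("Input cannot be cast to Tensor.")
```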
@@ -61,9 +63,9 @@ def calc_batch_size(batch_shape):
     Calculate the size of a given batch_shape.

     Args:
-        batch_shape (tuple)
+        batch_shape (tuple): batch shape to be calculated.

-    Outputs:
+    Returns:
         int.
     """
     return int(np.prod(batch_shape))
@@ -73,23 +75,26 @@ def convert_to_batch(t, batch_shape, dtype):
     Convert a Tensor to a given batch shape.

     Args:
-        t (Tensor)
-        batch_shape (tuple)
-        dtype (mindspore.dtype)
+        t (Tensor, Parameter): Tensor to be converted.
+        batch_shape (tuple): desired batch shape.
+        dtype (mindspore.dtype): desired dtype.

     Raises:
         RuntimeError: if the conversion cannot be done.

-    Outputs:
+    Returns:
         Tensor, with shape of batch_shape.
     """
+    if isinstance(t, Parameter):
+        return t
     t = cast_to_tensor(t, dtype)
-    reshape = P.Reshape()
     if t.shape != batch_shape:
         mul = calc_batch_size(batch_shape) // t.size()
         if (calc_batch_size(batch_shape) % t.size()) != 0:
             raise RuntimeError("Cannot cast the tensor to the given batch shape.")
         temp = list(t.asnumpy()) * mul
-        return reshape(Tensor(temp), batch_shape)
+        temp = np.reshape(temp, batch_shape)
+        return Tensor(temp, dtype)
     return t
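The tiling branch repeats the flattened values until they fill the requested batch, then reshapes, now via `np.reshape` on the host instead of a `P.Reshape` graph op. In plain numpy the same logic reads:

```python
import numpy as np

def convert_to_batch_sketch(arr, batch_shape):
    """Repeat arr to fill batch_shape, mirroring convert_to_batch above."""
    size = int(np.prod(batch_shape))
    if size % arr.size != 0:
        raise RuntimeError("Cannot cast the tensor to the given batch shape.")
    mul = size // arr.size
    return np.reshape(list(arr.ravel()) * mul, batch_shape)

print(convert_to_batch_sketch(np.array([1.0, 2.0]), (2, 2)))
# [[1. 2.]
#  [1. 2.]]
```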
 def check_scalar_from_param(params):
@@ -97,7 +102,7 @@ def check_scalar_from_param(params):
     Check if params are all scalars.

     Args:
-        params (dict): parameters used to initialized distribution.
+        params (dict): parameters used to initialize distribution.
     Notes: String parameters are excluded.
     """
@@ -116,9 +121,9 @@ def calc_broadcast_shape_from_param(params):
     Calculate the broadcast shape from params.

     Args:
-        params (dict): parameters used to initialized distribution.
+        params (dict): parameters used to initialize distribution.

-    Outputs:
+    Returns:
         tuple.
     """
     broadcast_shape = []
@@ -127,7 +132,10 @@ def calc_broadcast_shape_from_param(params):
             continue
         if value is None:
             return None
-        value_t = cast_to_tensor(value, params['dtype'])
+        if isinstance(value, Parameter):
+            value_t = value.default_input
+        else:
+            value_t = cast_to_tensor(value, params['dtype'])
         broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name'])
     return tuple(broadcast_shape)
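The accumulated broadcast shape follows ordinary numpy broadcasting over all tensor-valued parameters; e.g. (requires numpy >= 1.20 for `broadcast_shapes`):

```python
import numpy as np

# A mean of shape (1,) against an sd of shape (2,) gives batch shape (2,);
# shapes (2, 1) and (3,) broadcast to (2, 3), exactly numpy's rules.
assert np.broadcast_shapes((1,), (2,)) == (2,)
assert np.broadcast_shapes((2, 1), (3,)) == (2, 3)
```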
@@ -136,36 +144,37 @@ def check_greater_equal_zero(value, name):
     Check if the given Tensor is greater than or equal to zero.

     Args:
-        value (Tensor)
+        value (Tensor, Parameter): value to be checked.
         name (str): name of the value.

     Raises:
         ValueError: if the input value is less than zero.
     """
-    less = P.Less()
-    zeros = Tensor([0.0], dtype=value.dtype)
-    value = less(value, zeros)
-    if value.asnumpy().any():
-        raise ValueError('{} should be greater than zero.'.format(name))
+    if isinstance(value, Parameter):
+        if isinstance(value.default_input, MetaTensor):
+            return
+        value = value.default_input
+    comp = np.less(value.asnumpy(), np.zeros(value.shape))
+    if comp.any():
+        raise ValueError(f'{name} should be greater than zero.')
 def check_greater(a, b, name_a, name_b):
     """
     Check if Tensor b is strictly greater than Tensor a.

     Args:
-        a (Tensor)
-        b (Tensor)
+        a (Tensor): input tensor a.
+        b (Tensor): input tensor b.
         name_a (str): name of Tensor_a.
         name_b (str): name of Tensor_b.

     Raises:
         ValueError: if b is less than or equal to a.
     """
-    less = P.Less()
-    value = less(a, b)
-    if not value.asnumpy().all():
-        raise ValueError('{} should be less than {}'.format(name_a, name_b))
+    comp = np.less(a.asnumpy(), b.asnumpy())
+    if not comp.all():
+        raise ValueError(f'{name_a} should be less than {name_b}')
 def check_prob(p):
@@ -173,18 +182,18 @@ def check_prob(p):
     Check if p is a proper probability, i.e. 0 <= p <= 1.

     Args:
-        p (Tensor): value to check.
+        p (Tensor, Parameter): value to be checked.

     Raises:
         ValueError: if p is not a proper probability.
     """
-    less = P.Less()
-    greater = P.Greater()
-    zeros = Tensor([0.0], dtype=p.dtype)
-    ones = Tensor([1.0], dtype=p.dtype)
-    comp = less(p, zeros)
-    if comp.asnumpy().any():
+    if isinstance(p, Parameter):
+        if isinstance(p.default_input, MetaTensor):
+            return
+        p = p.default_input
+    comp = np.less(p.asnumpy(), np.zeros(p.shape))
+    if comp.any():
         raise ValueError('Probabilities should be greater than or equal to zero')
-    comp = greater(p, ones)
-    if comp.asnumpy().any():
+    comp = np.greater(p.asnumpy(), np.ones(p.shape))
+    if comp.any():
         raise ValueError('Probabilities should be less than or equal to one')
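These validators now run eagerly on the host with numpy rather than building `P.Less`/`P.Greater` graph ops, and they skip a `Parameter` whose `default_input` is still a `MetaTensor`, since a lazily initialized parameter has no concrete values to inspect yet. A condensed standalone sketch of the numpy path:

```python
import numpy as np

def check_prob_sketch(p):
    """Host-side probability check over a plain numpy array."""
    if np.less(p, np.zeros(p.shape)).any():
        raise ValueError('Probabilities should be greater than or equal to zero')
    if np.greater(p, np.ones(p.shape)).any():
        raise ValueError('Probabilities should be less than or equal to one')

check_prob_sketch(np.array([0.3, 0.5]))   # passes
# check_prob_sketch(np.array([1.2]))      # would raise ValueError
```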
@@ -23,21 +23,24 @@ class Bernoulli(Distribution):
     Example class: Bernoulli Distribution.

     Args:
-        probs (int/float/list/numpy.ndarray/Tensor): probability of 1 as outcome.
-        dtype (mindspore.dtype): type of the distribution, default to int32.
+        probs (int, float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome.
+        seed (int): seed to use in sampling. Default: 0.
+        dtype (mindspore.dtype): type of the distribution. Default: mstype.int32.
+        name (str): name of the distribution. Default: Bernoulli.

     Note:
         probs should be proper probabilities (0 <= p <= 1).

     Examples:
         >>> # To initialize a Bernoulli distribution which has equal probability of getting 1 and 0
-        >>> b = nn.Bernoulli(0.5, dtype = dtype.int32)
+        >>> b = nn.Bernoulli(0.5, dtype=mstype.int32)
         >>> # The following creates two independent Bernoulli distributions
-        >>> b = nn.Bernoulli([0.7, 0.2], dtype = dtype.int32)
+        >>> b = nn.Bernoulli([0.7, 0.2], dtype=mstype.int32)
     """

     def __init__(self,
                  probs=None,
+                 seed=0,
                  dtype=mstype.int32,
                  name="Bernoulli"):
         """
@@ -47,7 +50,6 @@ class Bernoulli(Distribution):
         super(Bernoulli, self).__init__(dtype, name, param)
         if probs is not None:
             self._probs = cast_to_tensor(probs)
-            # check if the input probability is valid
             check_prob(self._probs)
         else:
             self._probs = probs
@@ -58,7 +60,17 @@ class Bernoulli(Distribution):
         self.mul = P.Mul()
         self.sqrt = P.Sqrt()
         self.realdiv = P.RealDiv()
+        self.shape = P.Shape()
+        self.const = P.ScalarToArray()
+        self.less = P.Less()
+        self.cast = P.Cast()
+        self.normal = P.Normal(seed=seed)
+        self.erf = P.Erf()
+
+    def extend_repr(self):
+        str_info = f'probs = {self._probs}'
+        return str_info
     def probs(self):
         """
@@ -66,21 +78,25 @@ class Bernoulli(Distribution):
         """
         return self._probs

-    def _mean(self):
+    def _mean(self, name='mean', probs1=None):
         r"""
         .. math::
             MEAN(B) = probs1
         """
-        return self._probs
+        if name == 'mean':
+            return self._probs if probs1 is None else probs1
+        return None

-    def _var(self):
+    def _var(self, name='var', probs1=None):
         r"""
         .. math::
             VAR(B) = probs1 * probs0
         """
-        probs0 = self.add(1, -1 * self._probs)
-        return self.mul(probs0, self._probs)
+        if name in ('sd', 'var'):
+            probs1 = self._probs if probs1 is None else probs1
+            probs0 = self.add(1, -1 * probs1)
+            return self.mul(probs0, probs1)
+        return None
     def _prob(self, name, value, probs=None):
         r"""
@@ -89,18 +105,20 @@ class Bernoulli(Distribution):
         Args:
             name (str): name of the function. Should be "prob" when passed in from construct.
             value (Tensor): a Tensor composed of only zeros and ones.
-            probs (Tensor): probability of outcome is 1. Default to self._probs.
+            probs (Tensor): probability that the outcome is 1. Default: self._probs.

         .. math::
             pmf(k) = probs1 if k = 1;
             pmf(k) = probs0 if k = 0;
         """
-        probs1 = self._probs if probs is None else probs
-        probs0 = self.add(1, -1 * probs1)
-        return self.add(self.mul(probs1, value),
-                        self.mul(probs0, self.add(1, -1 * value)))
+        if name in ('prob', 'log_prob'):
+            probs1 = self._probs if probs is None else probs
+            probs0 = self.add(1, -1 * probs1)
+            return self.add(self.mul(probs1, value),
+                            self.mul(probs0, self.add(1, -1 * value)))
+        return None
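The pmf is evaluated branch-free as probs1 * k + probs0 * (1 - k), which collapses to probs1 when k = 1 and probs0 when k = 0; a quick numpy illustration:

```python
import numpy as np

probs1 = 0.7
k = np.array([0, 1, 0, 1, 1], dtype=np.float32)
pmf = probs1 * k + (1 - probs1) * (1 - k)
print(pmf)  # [0.3 0.7 0.3 0.7 0.7]
```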
-    def _kl_loss(self, name, dist, probs1_b):
+    def _kl_loss(self, name, dist, probs1_b, probs1_a=None):
         r"""
         Evaluate Bernoulli-Bernoulli KL divergence, i.e. KL(a||b).

@@ -108,19 +126,42 @@ class Bernoulli(Distribution):
         Args:
             name (str): name of the function. Should always be "kl_loss" when passed in from construct.
             dist (str): type of the distributions. Should be "Bernoulli" in this case.
             probs1_b (Tensor): probs1 of distribution b.
+            probs1_a (Tensor): probs1 of distribution a. Default: self._probs.

         .. math::
             KL(a||b) = probs1_a * \log(\frac{probs1_a}{probs1_b}) +
                        probs0_a * \log(\frac{probs0_a}{probs0_b})
         """
-        if dist == 'Bernoulli':
-            probs1_a = self._probs
+        if name == 'kl_loss' and dist == 'Bernoulli':
+            probs1_a = self._probs if probs1_a is None else probs1_a
             probs0_a = self.add(1, -1 * probs1_a)
             probs0_b = self.add(1, -1 * probs1_b)
             return self.add(probs1_a * self.log(self.realdiv(probs1_a, probs1_b)),
                             probs0_a * self.log(self.realdiv(probs0_a, probs0_b)))
         return None

-    def extend_repr(self):
-        str_info = 'probs={}'.format(self._probs)
-        return str_info
+    def _sample(self, name, shape=(), probs=None):
+        """
+        Sampling.
+
+        Args:
+            name (str): name of the function. Should always be 'sample' when passed in from construct.
+            shape (tuple): shape of the sample. Default: ().
+            probs (Tensor): probs1 of the samples. Default: self._probs.
+
+        Returns:
+            Tensor, shape is shape + batch_shape.
+        """
+        if name == 'sample':
+            probs1 = self._probs if probs is None else probs
+            batch_shape = self.shape(probs1)
+            sample_shape = shape + batch_shape
+            mean_zero = self.const(0.0)
+            sd_one = self.const(1.0)
+            sqrt_two = self.sqrt(self.const(2.0))
+            sample_norm = self.normal(sample_shape, mean_zero, sd_one)
+            sample_uniform = 0.5 * (1 + self.erf(self.realdiv(sample_norm, sqrt_two)))
+            sample = self.less(sample_uniform, probs1)
+            sample = self.cast(sample, self._dtype)
+            return sample
+        return None
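Because the runtime only exposes a normal sampler here, `_sample` maps normal draws to uniforms through the standard normal CDF, Phi(z) = 0.5 * (1 + erf(z / sqrt(2))), and thresholds them against probs1. A numpy/scipy sketch of the same trick, for illustration only:

```python
import numpy as np
from scipy.special import erf  # scipy used only for this illustration

rng = np.random.default_rng(0)
probs1 = np.array([0.7, 0.5])
shape = (2, 3)

z = rng.standard_normal(shape + probs1.shape)   # normal draws
u = 0.5 * (1 + erf(z / np.sqrt(2)))             # Phi(z) is Uniform(0, 1)
sample = (u < probs1).astype(np.int32)          # Bernoulli(probs1) draws
print(sample.shape)                             # (2, 3, 2)
```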
@@ -21,6 +21,11 @@ class Distribution(Cell):
     """
     Base class for all mathematical distributions.

+    Args:
+        dtype (mindspore.dtype): type of the distribution.
+        name (str): name of the distribution.
+        param (dict): parameters used to initialize the distribution.
+
     Note:
         Derived classes should override operations such as _mean, _prob,
         and _log_prob. Functions should be called through construct when
@@ -97,14 +102,8 @@ class Distribution(Cell):
         Note:
             value is cast to Tensor for further calculation.

-        Args:
-            name (str): name of the calling function.
-            value (Tensor): values to be evaluated.
-            mean (Tensor): mean of the distirbution. Default: self.mean.
-            sd (Tensor): standard deviation of the distribution. Default: self.sd.
-
-        Outputs:
-            Tensor, shape: broadcast_shape of the distribution.
+        Returns:
+            Tensor, shape is the broadcast_shape of the distribution.
         """
         return self._call_log_prob(*args)
@@ -114,36 +113,9 @@ class Distribution(Cell):
         .. math::
             probability(x) = \exp(log_likelihood(x))

-        Args:
-            name (str): name of the calling function.
-            value (Tensor): values to be evaluated.
-            mean (Tensor): mean of the distribution. Default: self.mean.
-            sd (Tensor): standard deviation of the distritbuion. Default: self.sd.
         """
         return self.exp(self._log_likelihood(*args))
-    def _call_prob(self, *args):
-        """
-        Raises:
-            NotImplementedError when derived class didn't override _prob or _log_likelihood.
-        """
-        raise NotImplementedError('pdf/pmf is not implemented: {}'.format(type(self).__name__))
-
-    def _call_log_prob(self, *args):
-        """
-        Raises:
-            NotImplementedError when derived class didn't override _prob or _log_likelihood.
-        """
-        raise NotImplementedError('log_probability is not implemented: {}'.format(type(self).__name__))
-
-    def _call_sd(self):
-        """
-        Raises:
-            NotImplementedError when derived class didn't override _sd or _var.
-        """
-        raise NotImplementedError('standard deviation is not implemented: {}'.format(type(self).__name__))
     def prob(self, *args):
         """
         Evaluate the prob (pdf or pmf) at given value.

@@ -151,14 +123,8 @@ class Distribution(Cell):
         Note:
             value is cast to Tensor for further calculation.

-        Args:
-            name (str): name of the calling function.
-            value (Tensor): values to be evaluated.
-            mean (Tensor): mean of the distribution.
-            sd (Tensor): standard deviation of the distritbuion.
-
-        Outputs:
-            Tensor, shape: broadcast_shape of the distribution.
+        Returns:
+            Tensor, shape is the broadcast_shape of the distribution.
         """
         return self._call_prob(*args)
@@ -176,8 +142,8 @@ class Distribution(Cell):
         Evaluate the KL divergence. Parameters of the second distribution should be
         passed in through **kwargs.

-        Outputs:
-            Tensor, shape: broadcast_shape of the distribution and input distribution.
+        Returns:
+            Tensor, shape is the broadcast_shape of the distribution and input distribution.
         """
         return self._kl_loss(**kwargs)

@@ -185,8 +151,8 @@ class Distribution(Cell):
         """
         Evaluate the mean.

-        Outputs:
-            Tensor, shape: broadcast_shape of the distribution.
+        Returns:
+            Tensor, shape is the broadcast_shape of the distribution.
         """
         return self._mean(**kwargs)
@@ -194,19 +160,19 @@ class Distribution(Cell):
         """
         Evaluate the standard deviation.

-        Outputs:
-            Tensor, with shape of broadcast_shape of the distribution.
+        Returns:
+            Tensor, shape is the broadcast_shape of the distribution.
         """
         return self._call_sd(**kwargs)

-    def _calc_sd_from_var(self, **kwargs):
+    def _calc_sd_from_var(self, *args):
         r"""
         Evaluate the standard deviation from the variance.

         .. math::
             STD(x) = \sqrt{VAR(x)}
         """
-        return self.sqrt(self._var(**kwargs))
+        return self.sqrt(self._var(*args))
     def construct(self, *inputs):
         """
@@ -226,7 +192,9 @@ class Distribution(Cell):
         if inputs[0] == 'kl_loss':
             return self._kl_loss(*inputs)
         if inputs[0] == 'mean':
-            return self._mean()
+            return self._mean(*inputs)
         if inputs[0] == 'sd':
-            return self._call_sd()
+            return self._call_sd(*inputs)
+        if inputs[0] == 'sample':
+            return self._sample(*inputs)
         return None
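With this dispatch, every capability of a distribution cell is invoked by calling the cell with a leading name string, mirroring the tests further below; for instance (assuming the distribution cells are exposed under nn, as in those tests):

```python
from mindspore import Tensor, nn
from mindspore import dtype as mstype

b = nn.Bernoulli(0.5, seed=0, dtype=mstype.int32)
value = Tensor([1, 0, 1], dtype=mstype.float32)

pmf = b('prob', value)      # routed to _prob
mean = b('mean')            # routed to _mean
draws = b('sample', (2,))   # routed to _sample
```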
@@ -25,23 +25,27 @@ class Normal(Distribution):
     Example class: Normal distribution.

     Args:
-        mean (int/float/list/numpy.ndarray/Tensor): mean of the Gaussian distribution
-        standard deviation (int/float/list/numpy.ndarray/Tensor): vairance of the Gaussian distribution
-        dtype (mindspore.dtype): type of the distribution
+        mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Gaussian distribution.
+        sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Gaussian distribution.
+        seed (int): seed to use in sampling. Default: 0.
+        dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
+        name (str): name of the distribution. Default: Normal.

     Note:
         Standard deviation should be greater than zero.

     Examples:
         >>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0
-        >>> n = nn.Normal(3.0, 4.0, dtype=dtype.float32)
+        >>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32)
         >>> # The following creates two independent normal distributions
-        >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=dtype.float32)
+        >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
     """

     def __init__(self,
                  mean=None,
                  sd=None,
+                 seed=0,
                  dtype=mstype.float32,
                  name="Normal"):
         """
@@ -52,7 +56,6 @@ class Normal(Distribution):
         if mean is not None and sd is not None:
             self._mean_value = convert_to_batch(mean, self._broadcast_shape, dtype)
             self._sd_value = convert_to_batch(sd, self._broadcast_shape, dtype)
-            # check validity of standard deviation
             check_greater_equal_zero(self._sd_value, "Standard deviation")
         else:
             self._mean_value = mean
@@ -61,11 +64,20 @@ class Normal(Distribution):
         # ops needed for the class
         self.exp = P.Exp()
         self.add = P.TensorAdd()
+        self.mul = P.Mul()
         self.sq = P.Square()
         self.log = P.Log()
         self.sqrt = P.Sqrt()
         self.realdiv = P.RealDiv()
         self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step
+        self.normal = P.Normal(seed=seed)
+        self.shape = P.Shape()
+        self.zeroslike = P.ZerosLike()
+        self.const = P.ScalarToArray()
+
+    def extend_repr(self):
+        str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}'
+        return str_info
     def _expm1_by_step(self, x):
         """
@@ -73,17 +85,23 @@ class Normal(Distribution):
         """
         return self.add(self.exp(x), -1)
-    def _mean(self):
+    def _mean(self, name='mean', mean=None, sd=None):
         """
         Mean of the distribution.
         """
-        return self._mean_value
+        if name == 'mean':
+            mean = self._mean_value if mean is None or sd is None else mean
+            return mean
+        return None

-    def _sd(self):
+    def _sd(self, name='sd', mean=None, sd=None):
         """
         Standard deviation of the distribution.
         """
-        return self._sd_value
+        if name in ('sd', 'var'):
+            sd = self._sd_value if mean is None or sd is None else sd
+            return sd
+        return None
     def _log_likelihood(self, name, value, mean=None, sd=None):
         r"""
@@ -92,33 +110,60 @@ class Normal(Distribution):
         .. math::
             L(x) = -1 * \frac{(x - \mu)^2}{2 \sigma^2} - \log(\sqrt{2 \pi \sigma^2})
         """
-        mean = self._mean_value if mean is None else mean
-        sd = self._sd_value if sd is None else sd
-        unnormalized_log_prob = -1. * self.realdiv(self.sq(self.add(value, -1. * mean)),
-                                                   2. * self.sq(sd))
-        neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
-        return self.add(unnormalized_log_prob, neg_normalization)
+        if name in ('prob', 'log_prob'):
+            mean = self._mean_value if mean is None else mean
+            sd = self._sd_value if sd is None else sd
+            unnormalized_log_prob = -1. * self.realdiv(self.sq(self.add(value, -1. * mean)),
+                                                       2. * self.sq(sd))
+            neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
+            return self.add(unnormalized_log_prob, neg_normalization)
+        return None

-    def _kl_loss(self, name, dist, mean, sd):
+    def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None):
         r"""
         Evaluate Normal-Normal KL divergence, i.e. KL(a||b).

         Args:
             name (str): name of the function passed in from construct. Should always be "kl_loss".
             dist (str): type of the distributions. Should be "Normal" in this case.
-            mean (Tensor): mean of distribution b.
-            sd (Tensor): standard deviation distribution b.
+            mean_b (Tensor): mean of distribution b.
+            sd_b (Tensor): standard deviation of distribution b.
+            mean_a (Tensor): mean of distribution a. Default: self._mean_value.
+            sd_a (Tensor): standard deviation of distribution a. Default: self._sd_value.

         .. math::
             KL(a||b) = 0.5 * (\frac{MEAN(a)}{STD(b)} - \frac{MEAN(b)}{STD(b)})^2 +
                        0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b)))) - (\log(STD(a)) - \log(STD(b)))
         """
-        if dist == 'Normal':
-            diff_log_scale = self.add(self.log(self._sd_value), - self.log(sd))
-            squared_diff = self.sq(self.add(self.realdiv(self._mean_value, sd), - self.realdiv(mean, sd)))
+        if name == 'kl_loss' and dist == 'Normal':
+            mean_a = self._mean_value if mean_a is None else mean_a
+            sd_a = self._sd_value if sd_a is None else sd_a
+            diff_log_scale = self.add(self.log(sd_a), - self.log(sd_b))
+            squared_diff = self.sq(self.add(self.realdiv(mean_a, sd_b), - self.realdiv(mean_b, sd_b)))
             return self.add(self.add(0.5 * squared_diff, 0.5 * self.expm1(2 * diff_log_scale)), - diff_log_scale)
         return None
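The expm1 form above is algebraically the textbook closed form KL(a||b) = log(sd_b/sd_a) + (sd_a^2 + (mean_a - mean_b)^2) / (2 * sd_b^2) - 1/2; a quick numpy check of the equivalence (standalone, not part of the change):

```python
import numpy as np

mean_a, sd_a = 2.0, 3.0
mean_b, sd_b = 1.0, 1.0

# expm1 form used by _kl_loss above
diff_log_scale = np.log(sd_a) - np.log(sd_b)
kl_expm1 = (0.5 * ((mean_a - mean_b) / sd_b) ** 2
            + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale)

# textbook closed form
kl_closed = (np.log(sd_b / sd_a)
             + (sd_a ** 2 + (mean_a - mean_b) ** 2) / (2 * sd_b ** 2) - 0.5)

assert np.isclose(kl_expm1, kl_closed)  # both equal ~3.4014
```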
-    def extend_repr(self):
-        str_info = 'mean={}, standard deviation={}'.format(self._mean_value, self._sd_value)
-        return str_info
-
+    def _sample(self, name, shape=(), mean=None, sd=None):
+        """
+        Sampling.
+
+        Args:
+            name (str): name of the function. Should always be 'sample' when passed in from construct.
+            shape (tuple): shape of the sample. Default: ().
+            mean (Tensor): mean of the samples. Default: self._mean_value.
+            sd (Tensor): standard deviation of the samples. Default: self._sd_value.
+
+        Returns:
+            Tensor, shape is shape + batch_shape.
+        """
+        if name == 'sample':
+            mean = self._mean_value if mean is None else mean
+            sd = self._sd_value if sd is None else sd
+            batch_shape = self.shape(self.add(self.zeroslike(mean), self.zeroslike(sd)))
+            sample_shape = shape + batch_shape
+            mean_zero = self.const(0.0)
+            sd_one = self.const(1.0)
+            sample_norm = self.normal(sample_shape, mean_zero, sd_one)
+            sample = self.add(mean, self.mul(sample_norm, sd))
+            return sample
+        return None
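`_sample` draws standard normal noise and applies the reparameterization x = mean + sd * z, with the batch shape obtained by broadcasting zeros-like tensors of mean and sd; in numpy terms:

```python
import numpy as np

rng = np.random.default_rng(10)
mean = np.array([3.0])
sd = np.array([[2.0], [4.0]])
shape = (2, 3)

batch_shape = np.broadcast_shapes(mean.shape, sd.shape)  # (2, 1)
z = rng.standard_normal(shape + batch_shape)             # standard normal noise
sample = mean + sd * z                                   # reparameterization
print(sample.shape)                                      # (2, 3, 2, 1)
```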
@@ -65,12 +65,25 @@ class Net3(nn.Cell):
     """
     def __init__(self):
         super(Net3, self).__init__()
-        self.b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32)
+        self.b = nn.Bernoulli([0.5, 0.5], dtype=dtype.int32)

     @ms_function
     def construct(self):
         return self.b('mean'), self.b('sd')

+class Net4(nn.Cell):
+    """
+    Test class: sample from a Bernoulli distribution.
+    """
+    def __init__(self, shape, seed=0):
+        super(Net4, self).__init__()
+        self.b = nn.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
+        self.shape = shape
+
+    @ms_function
+    def construct(self, probs=None):
+        return self.b('sample', self.shape, probs)
 def test_pmf():
     """
     Test pmf.
@@ -80,10 +93,8 @@ def test_pmf():
     pdf = Net()
     x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
     output = pdf(x_)
-    print("expected_pmf: ", expect_pmf)
-    print("ans: ", output.asnumpy())
     tol = 1e-6
-    assert (output.asnumpy() - expect_pmf < tol).all()
+    assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()

 def test_log_likelihood():
     """
@@ -94,10 +105,8 @@ def test_log_likelihood():
     logprob = Net1()
     x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
     output = logprob(x_)
-    print("expected_log_probability: ", expect_logpmf)
-    print("ans: ", output.asnumpy())
     tol = 1e-6
-    assert (output.asnumpy() - expect_logpmf < tol).all()
+    assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()

 def test_kl_loss():
     """
@@ -110,10 +119,8 @@ def test_kl_loss():
     expect_kl_loss = probs1_a * np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b)
     kl_loss = Net2()
     output = kl_loss(Tensor([probs1_b], dtype=dtype.float32))
-    print("expected_kl_loss: ", expect_kl_loss)
-    print("ans: ", output.asnumpy())
     tol = 1e-6
-    assert (output.asnumpy() - expect_kl_loss < tol).all()
+    assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
 def test_basics():
     """
@@ -121,8 +128,20 @@ def test_basics():
     """
     basics = Net3()
     mean, sd = basics()
-    print("mean : ", mean)
-    print("sd : ", sd)
+    expect_mean = [0.5, 0.5]
+    assert (mean.asnumpy() == expect_mean).all()
+    assert (sd.asnumpy() == expect_mean).all()
     b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32)
     probs = b.probs()
-    print("probs is ", probs)
+    expect_probs = [0.7, 0.5]
+    tol = 1e-6
+    assert (np.abs(probs.asnumpy() - expect_probs) < tol).all()

+def test_sample():
+    """
+    Test sample.
+    """
+    shape = (2, 3)
+    sample = Net4(shape)
+    output = sample()
+    assert output.shape == (2, 3, 2)
@@ -65,12 +65,25 @@ class Net3(nn.Cell):
     """
     def __init__(self):
         super(Net3, self).__init__()
-        self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
+        self.n = nn.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32)

     @ms_function
     def construct(self):
         return self.n('mean'), self.n('sd')

+class Net4(nn.Cell):
+    """
+    Test class: sample from a normal distribution.
+    """
+    def __init__(self, shape, seed=0):
+        super(Net4, self).__init__()
+        self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32)
+        self.shape = shape
+
+    @ms_function
+    def construct(self, mean=None, sd=None):
+        return self.n('sample', self.shape, mean, sd)
 def test_pdf():
     """
     Test pdf.
@@ -79,10 +92,8 @@ def test_pdf():
     expect_pdf = norm_benchmark.pdf([1.0, 2.0]).astype(np.float32)
     pdf = Net()
     output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32))
-    print("expected_pdf: ", expect_pdf)
-    print("ans: ", output.asnumpy())
     tol = 1e-6
-    assert (output.asnumpy() - expect_pdf < tol).all()
+    assert (np.abs(output.asnumpy() - expect_pdf) < tol).all()

 def test_log_likelihood():
     """
@@ -92,10 +103,8 @@ def test_log_likelihood():
     expect_logpdf = norm_benchmark.logpdf([1.0, 2.0]).astype(np.float32)
     logprob = Net1()
     output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32))
-    print("expected_log_probability: ", expect_logpdf)
-    print("ans: ", output.asnumpy())
     tol = 1e-6
-    assert (output.asnumpy() - expect_logpdf < tol).all()
+    assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all()

 def test_kl_loss():
     """
@@ -115,10 +124,8 @@ def test_kl_loss():
     mean = Tensor(mean_b, dtype=dtype.float32)
     sd = Tensor(sd_b, dtype=dtype.float32)
     output = kl_loss(mean, sd)
-    print("expected_kl_loss: ", expect_kl_loss)
-    print("ans: ", output.asnumpy())
     tol = 1e-6
-    assert (output.asnumpy() - expect_kl_loss < tol).all()
+    assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
 def test_basics():
     """
@@ -126,5 +133,20 @@ def test_basics():
     """
     basics = Net3()
     mean, sd = basics()
-    print("mean is ", mean)
-    print("sd is ", sd)
+    expect_mean = [3.0, 3.0]
+    expect_sd = [2.0, 4.0]
+    tol = 1e-6
+    assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
+    assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()

+def test_sample():
+    """
+    Test sample.
+    """
+    shape = (2, 3)
+    seed = 10
+    mean = Tensor([2.0], dtype=dtype.float32)
+    sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
+    sample = Net4(shape, seed=seed)
+    output = sample(mean, sd)
+    assert output.shape == (2, 3, 3)
@@ -36,18 +36,18 @@ def test_no_arguments():
     No args passed in during initialization.
     """
     n = nn.Normal()
+    assert isinstance(n, nn.Distribution)
     b = nn.Bernoulli()
-    print(n)
-    print(b)
+    assert isinstance(b, nn.Distribution)

 def test_with_arguments():
     """
     Args passed in during initialization.
     """
     n = nn.Normal([3.0], [4.0], dtype=dtype.float32)
+    assert isinstance(n, nn.Distribution)
     b = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32)
-    print(n)
-    print(b)
+    assert isinstance(b, nn.Distribution)
 class NormalProb(nn.Cell):
     """
@@ -69,8 +69,8 @@ def test_normal_prob():
     net = NormalProb()
     value = Tensor([0.5, 1.0], dtype=dtype.float32)
     pdf, log_pdf = net(value)
-    print("pdf: ", pdf)
-    print("log_pdf: ", log_pdf)
+    assert isinstance(pdf, Tensor)
+    assert isinstance(log_pdf, Tensor)

 class NormalProb1(nn.Cell):
     """
@@ -94,9 +94,8 @@ def test_normal_prob1():
     mean = Tensor([0.0], dtype=dtype.float32)
     sd = Tensor([1.0], dtype=dtype.float32)
     pdf, log_pdf = net(value, mean, sd)
-    print("pdf: ", pdf)
-    print("log_pdf: ", log_pdf)
+    assert isinstance(pdf, Tensor)
+    assert isinstance(log_pdf, Tensor)

 class NormalProb2(nn.Cell):
     """
@@ -121,8 +120,8 @@ def test_normal_prob2():
     mean = Tensor([0.0], dtype=dtype.float32)
     sd = Tensor([1.0], dtype=dtype.float32)
     pdf, log_pdf = net(value, mean, sd)
-    print("pdf: ", pdf)
-    print("log_pdf: ", log_pdf)
+    assert isinstance(pdf, Tensor)
+    assert isinstance(log_pdf, Tensor)
 class BernoulliProb(nn.Cell):
     """
@@ -133,9 +132,19 @@ class BernoulliProb(nn.Cell):
         self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32)

     def construct(self, value):
-        x = self.bernoulli('prob', value)
-        y = self.bernoulli('log_prob', value)
-        return x, y
+        return self.bernoulli('prob', value)

+class BernoulliLogProb(nn.Cell):
+    """
+    Bernoulli distribution: initialize with probs.
+    """
+    def __init__(self):
+        super(BernoulliLogProb, self).__init__()
+        self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32)
+
+    def construct(self, value):
+        return self.bernoulli('log_prob', value)
 def test_bernoulli_prob():
     """
@@ -143,10 +152,17 @@ def test_bernoulli_prob():
     """
     net = BernoulliProb()
     value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
-    ans = net(value)
-    print("pmf: ", ans)
-    print("log_pmf: ", ans)
+    pmf = net(value)
+    assert isinstance(pmf, Tensor)

+def test_bernoulli_log_prob():
+    """
+    Test log_pmf: passing value through construct.
+    """
+    net = BernoulliLogProb()
+    value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
+    log_pmf = net(value)
+    assert isinstance(log_pmf, Tensor)
 class BernoulliProb1(nn.Cell):
     """
@@ -157,9 +173,19 @@ class BernoulliProb1(nn.Cell):
         self.bernoulli = nn.Bernoulli()

     def construct(self, value, probs):
-        x = self.bernoulli('prob', value, probs)
-        y = self.bernoulli('log_prob', value, probs)
-        return x, y
+        return self.bernoulli('prob', value, probs)

+class BernoulliLogProb1(nn.Cell):
+    """
+    Bernoulli distribution: initialize without probs.
+    """
+    def __init__(self):
+        super(BernoulliLogProb1, self).__init__()
+        self.bernoulli = nn.Bernoulli()
+
+    def construct(self, value, probs):
+        return self.bernoulli('log_prob', value, probs)

 def test_bernoulli_prob1():
     """
@@ -168,10 +194,18 @@ def test_bernoulli_prob1():
     net = BernoulliProb1()
     value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
     probs = Tensor([0.3], dtype=dtype.float32)
-    ans = net(value, probs)
-    print("pmf: ", ans)
-    print("log_pmf: ", ans)
+    pmf = net(value, probs)
+    assert isinstance(pmf, Tensor)

+def test_bernoulli_log_prob1():
+    """
+    Test log_pmf: passing probs through construct.
+    """
+    net = BernoulliLogProb1()
+    value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
+    probs = Tensor([0.3], dtype=dtype.float32)
+    log_pmf = net(value, probs)
+    assert isinstance(log_pmf, Tensor)
 class BernoulliProb2(nn.Cell):
     """
@@ -182,9 +216,19 @@ class BernoulliProb2(nn.Cell):
         self.bernoulli = nn.Bernoulli(0.5)

     def construct(self, value, probs):
-        x = self.bernoulli('prob', value, probs)
-        y = self.bernoulli('log_prob', value, probs)
-        return x, y
+        return self.bernoulli('prob', value, probs)

+class BernoulliLogProb2(nn.Cell):
+    """
+    Bernoulli distribution: initialize with probs.
+    """
+    def __init__(self):
+        super(BernoulliLogProb2, self).__init__()
+        self.bernoulli = nn.Bernoulli(0.5)
+
+    def construct(self, value, probs):
+        return self.bernoulli('log_prob', value, probs)

 def test_bernoulli_prob2():
     """
@@ -194,9 +238,20 @@ def test_bernoulli_prob2():
     net = BernoulliProb2()
     value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
     probs = Tensor([0.3], dtype=dtype.float32)
-    ans = net(value, probs)
-    print("pmf: ", ans)
-    print("log_pmf: ", ans)
+    pmf = net(value, probs)
+    assert isinstance(pmf, Tensor)

+def test_bernoulli_log_prob2():
+    """
+    Test log_pmf: passing probs/value through construct.
+    Overwrite original probs.
+    """
+    net = BernoulliLogProb2()
+    value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
+    probs = Tensor([0.3], dtype=dtype.float32)
+    log_pmf = net(value, probs)
+    assert isinstance(log_pmf, Tensor)
 class NormalKl(nn.Cell):
     """
@@ -229,13 +284,61 @@ def test_kl():
     sd_b = np.array([1.0]).astype(np.float32)
     mean = Tensor(mean_b, dtype=dtype.float32)
     sd = Tensor(sd_b, dtype=dtype.float32)
-    output = nor_net(mean, sd)
-    print("normal-normal kl loss: ", output)
+    loss = nor_net(mean, sd)
+    assert isinstance(loss, Tensor)
     ber_net = BernoulliKl()
     probs_b = Tensor([0.3], dtype=dtype.float32)
-    output = ber_net(probs_b)
-    print("bernoulli-bernoulli kl loss: ", output)
+    loss = ber_net(probs_b)
+    assert isinstance(loss, Tensor)
+class NormalKlNoArgs(nn.Cell):
+    """
+    Test class: kl_loss of Normal distribution.
+    No args during initialization.
+    """
+    def __init__(self):
+        super(NormalKlNoArgs, self).__init__()
+        self.n = nn.Normal(dtype=dtype.float32)
+
+    def construct(self, x_, y_, w_, v_):
+        return self.n('kl_loss', 'Normal', x_, y_, w_, v_)
+
+class BernoulliKlNoArgs(nn.Cell):
+    """
+    Test class: kl_loss between Bernoulli distributions.
+    No args during initialization.
+    """
+    def __init__(self):
+        super(BernoulliKlNoArgs, self).__init__()
+        self.b = nn.Bernoulli(dtype=dtype.int32)
+
+    def construct(self, x_, y_):
+        return self.b('kl_loss', 'Bernoulli', x_, y_)
+
+def test_kl_no_args():
+    """
+    Test kl_loss function.
+    """
+    nor_net = NormalKlNoArgs()
+    mean_b = np.array([1.0]).astype(np.float32)
+    sd_b = np.array([1.0]).astype(np.float32)
+    mean_a = np.array([2.0]).astype(np.float32)
+    sd_a = np.array([3.0]).astype(np.float32)
+    mean_b = Tensor(mean_b, dtype=dtype.float32)
+    sd_b = Tensor(sd_b, dtype=dtype.float32)
+    mean_a = Tensor(mean_a, dtype=dtype.float32)
+    sd_a = Tensor(sd_a, dtype=dtype.float32)
+    loss = nor_net(mean_b, sd_b, mean_a, sd_a)
+    assert isinstance(loss, Tensor)
+
+    ber_net = BernoulliKlNoArgs()
+    probs_b = Tensor([0.3], dtype=dtype.float32)
+    probs_a = Tensor([0.7], dtype=dtype.float32)
+    loss = ber_net(probs_b, probs_a)
+    assert isinstance(loss, Tensor)
 class NormalBernoulli(nn.Cell):
@@ -244,7 +347,7 @@ class NormalBernoulli(nn.Cell):
     """
     def __init__(self):
         super(NormalBernoulli, self).__init__()
-        self.n = nn.Normal(3.0, 4.0, dtype=dtype.int32)
+        self.n = nn.Normal(3.0, 4.0, dtype=dtype.float32)
         self.b = nn.Bernoulli(0.5, dtype=dtype.int32)

     def construct(self):
@@ -260,7 +363,7 @@ def test_bascis():
     """
     net = NormalBernoulli()
     normal_mean, normal_sd, bernoulli_mean, bernoulli_sd = net()
-    print("Mean of Normal distribution: ", normal_mean)
-    print("Standard deviation of Normal distribution: ", normal_sd)
-    print("Mean of Bernoulli distribution: ", bernoulli_mean)
-    print("Standard deviation of Bernoulli distribution: ", bernoulli_sd)
+    assert isinstance(normal_mean, Tensor)
+    assert isinstance(normal_sd, Tensor)
+    assert isinstance(bernoulli_mean, Tensor)
+    assert isinstance(bernoulli_sd, Tensor)