Merge pull request !3394 from XunDeng/pp_poc_v3tags/v0.7.0-beta
| @@ -21,7 +21,13 @@ The high-level components(Distributions) used to construct the probabilistic net | |||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from .normal import Normal | from .normal import Normal | ||||
| from .bernoulli import Bernoulli | from .bernoulli import Bernoulli | ||||
| from .exponential import Exponential | |||||
| from .uniform import Uniform | |||||
| from .geometric import Geometric | |||||
| __all__ = ['Distribution', | __all__ = ['Distribution', | ||||
| 'Normal', | 'Normal', | ||||
| 'Bernoulli',] | |||||
| 'Bernoulli', | |||||
| 'Exponential', | |||||
| 'Uniform', | |||||
| 'Geometric',] | |||||
| @@ -17,8 +17,11 @@ Distribution operation utility functions. | |||||
| """ | """ | ||||
| from .utils import * | from .utils import * | ||||
| __all__ = ['check_scalar', 'convert_to_batch', 'cast_to_tensor', | |||||
| 'calc_batch_size', 'check_greater', | |||||
| __all__ = ['convert_to_batch', | |||||
| 'cast_to_tensor', | |||||
| 'check_greater', | |||||
| 'check_greater_equal_zero', | 'check_greater_equal_zero', | ||||
| 'check_greater_zero', | |||||
| 'calc_broadcast_shape_from_param', | 'calc_broadcast_shape_from_param', | ||||
| 'check_scalar_from_param', 'check_prob'] | |||||
| 'check_scalar_from_param', | |||||
| 'check_prob'] | |||||
| @@ -20,17 +20,10 @@ from ....common.tensor import Tensor | |||||
| from ....common.parameter import Parameter | from ....common.parameter import Parameter | ||||
| from ....common import dtype as mstype | from ....common import dtype as mstype | ||||
| def check_scalar(value): | |||||
| """ | |||||
| Check if input value is a scalar. | |||||
| """ | |||||
| return np.isscalar(value) | |||||
| def cast_to_tensor(t, dtype=mstype.float32): | def cast_to_tensor(t, dtype=mstype.float32): | ||||
| """ | """ | ||||
| Cast an user input value into a Tensor of dtype. | Cast an user input value into a Tensor of dtype. | ||||
| If the input t is of type Parameter, t is directly returned as a Parameter. | |||||
| Args: | Args: | ||||
| t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor. | t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor. | ||||
| @@ -54,22 +47,10 @@ def cast_to_tensor(t, dtype=mstype.float32): | |||||
| return t | return t | ||||
| if isinstance(t, (list, np.ndarray)): | if isinstance(t, (list, np.ndarray)): | ||||
| return Tensor(t, dtype=dtype) | return Tensor(t, dtype=dtype) | ||||
| if check_scalar(t): | |||||
| if np.isscalar(t): | |||||
| return Tensor([t], dtype=dtype) | return Tensor([t], dtype=dtype) | ||||
| raise RuntimeError("Input type is not supported.") | raise RuntimeError("Input type is not supported.") | ||||
| def calc_batch_size(batch_shape): | |||||
| """ | |||||
| Calculate the size of a given batch_shape. | |||||
| Args: | |||||
| batch_shape (tuple): batch shape to be calculated. | |||||
| Returns: | |||||
| int. | |||||
| """ | |||||
| return int(np.prod(batch_shape)) | |||||
| def convert_to_batch(t, batch_shape, dtype): | def convert_to_batch(t, batch_shape, dtype): | ||||
| """ | """ | ||||
| Convert a Tensor to a given batch shape. | Convert a Tensor to a given batch shape. | ||||
| @@ -87,15 +68,9 @@ def convert_to_batch(t, batch_shape, dtype): | |||||
| """ | """ | ||||
| if isinstance(t, Parameter): | if isinstance(t, Parameter): | ||||
| return t | return t | ||||
| t = cast_to_tensor(t, dtype) | |||||
| if t.shape != batch_shape: | |||||
| mul = calc_batch_size(batch_shape) // t.size() | |||||
| if (calc_batch_size(batch_shape) % t.size()) != 0: | |||||
| raise RuntimeError("Cannot cast the tensor to the given batch shape.") | |||||
| temp = list(t.asnumpy()) * mul | |||||
| temp = np.reshape(temp, batch_shape) | |||||
| return Tensor(temp, dtype) | |||||
| return t | |||||
| if isinstance(t, Tensor): | |||||
| return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=dtype) | |||||
| return Tensor(np.broadcast_to(t, batch_shape), dtype=dtype) | |||||
| def check_scalar_from_param(params): | def check_scalar_from_param(params): | ||||
| """ | """ | ||||
| @@ -107,9 +82,11 @@ def check_scalar_from_param(params): | |||||
| Notes: String parameters are excluded. | Notes: String parameters are excluded. | ||||
| """ | """ | ||||
| for value in params.values(): | for value in params.values(): | ||||
| if isinstance(value, Parameter): | |||||
| return False | |||||
| if isinstance(value, (str, type(params['dtype']))): | if isinstance(value, (str, type(params['dtype']))): | ||||
| continue | continue | ||||
| elif check_scalar(value): | |||||
| elif np.isscalar(value): | |||||
| continue | continue | ||||
| else: | else: | ||||
| return False | return False | ||||
| @@ -157,6 +134,26 @@ def check_greater_equal_zero(value, name): | |||||
| value = value.default_input | value = value.default_input | ||||
| comp = np.less(value.asnumpy(), np.zeros(value.shape)) | comp = np.less(value.asnumpy(), np.zeros(value.shape)) | ||||
| if comp.any(): | if comp.any(): | ||||
| raise ValueError(f'{name} should be greater than ot equal to zero.') | |||||
| def check_greater_zero(value, name): | |||||
| """ | |||||
| Check if the given Tensor is strictly greater than zero. | |||||
| Args: | |||||
| value (Tensor, Parameter): value to be checked. | |||||
| name (str) : name of the value. | |||||
| Raises: | |||||
| ValueError: if the input value is less than or equal to zero. | |||||
| """ | |||||
| if isinstance(value, Parameter): | |||||
| if isinstance(value.default_input, MetaTensor): | |||||
| return | |||||
| value = value.default_input | |||||
| comp = np.less(np.zeros(value.shape), value.asnumpy()) | |||||
| if not comp.all(): | |||||
| raise ValueError(f'{name} should be greater than zero.') | raise ValueError(f'{name} should be greater than zero.') | ||||
| def check_greater(a, b, name_a, name_b): | def check_greater(a, b, name_a, name_b): | ||||
| @@ -164,14 +161,16 @@ def check_greater(a, b, name_a, name_b): | |||||
| Check if Tensor b is strictly greater than Tensor a. | Check if Tensor b is strictly greater than Tensor a. | ||||
| Args: | Args: | ||||
| a (Tensor): input tensor a. | |||||
| b (Tensor): input tensor b. | |||||
| a (Tensor, Parameter): input tensor a. | |||||
| b (Tensor, Parameter): input tensor b. | |||||
| name_a (str): name of Tensor_a. | name_a (str): name of Tensor_a. | ||||
| name_b (str): name of Tensor_b. | name_b (str): name of Tensor_b. | ||||
| Raises: | Raises: | ||||
| ValueError: if b is less than or equal to a | ValueError: if b is less than or equal to a | ||||
| """ | """ | ||||
| if isinstance(a, Parameter) or isinstance(b, Parameter): | |||||
| return | |||||
| comp = np.less(a.asnumpy(), b.asnumpy()) | comp = np.less(a.asnumpy(), b.asnumpy()) | ||||
| if not comp.all(): | if not comp.all(): | ||||
| raise ValueError(f'{name_a} should be less than {name_b}') | raise ValueError(f'{name_a} should be less than {name_b}') | ||||
| @@ -14,29 +14,75 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Bernoulli Distribution""" | """Bernoulli Distribution""" | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import composite as C | |||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import cast_to_tensor, check_prob | from ._utils.utils import cast_to_tensor, check_prob | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| class Bernoulli(Distribution): | class Bernoulli(Distribution): | ||||
| """ | """ | ||||
| Example class: Bernoulli Distribution. | |||||
| Bernoulli Distribution. | |||||
| Args: | Args: | ||||
| probs (int, float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome. | |||||
| probs (float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome. | |||||
| seed (int): seed to use in sampling. Default: 0. | seed (int): seed to use in sampling. Default: 0. | ||||
| dtype (mindspore.dtype): type of the distribution. Default: mstype.int32. | dtype (mindspore.dtype): type of the distribution. Default: mstype.int32. | ||||
| name (str): name of the distribution. Default: Bernoulli. | name (str): name of the distribution. Default: Bernoulli. | ||||
| Note: | Note: | ||||
| probs should be proper probabilities (0 <= p <= 1). | probs should be proper probabilities (0 <= p <= 1). | ||||
| Dist_spec_args is probs. | |||||
| Examples: | Examples: | ||||
| >>> # To initialize a Bernoulli distribution which has equal probability of getting 1 and 0 | |||||
| >>> b = nn.Bernoulli(0.5, dtype = mstype.int32) | |||||
| >>> # The following create two independent Bernoulli distributions | |||||
| >>> b = nn.Bernoulli([0.7, 0.2], dtype = mstype.int32) | |||||
| >>> # To initialize a Bernoulli distribution of prob 0.5 | |||||
| >>> n = nn.Bernoulli(0.5, dtype=mstype.int32) | |||||
| >>> | |||||
| >>> # The following creates two independent Bernoulli distributions | |||||
| >>> n = nn.Bernoulli([0.5, 0.5], dtype=mstype.int32) | |||||
| >>> | |||||
| >>> # A Bernoulli distribution can be initilized without arguments | |||||
| >>> # In this case, probs must be passed in through construct. | |||||
| >>> n = nn.Bernoulli(dtype=mstype.int32) | |||||
| >>> | |||||
| >>> # To use Bernoulli distribution in a network | |||||
| >>> class net(Cell): | |||||
| >>> def __init__(self): | |||||
| >>> super(net, self).__init__(): | |||||
| >>> self.b1 = nn.Bernoulli(0.5, dtype=mstype.int32) | |||||
| >>> self.b2 = nn.Bernoulli(dtype=mstype.int32) | |||||
| >>> | |||||
| >>> # All the following calls in construct are valid | |||||
| >>> def construct(self, value, probs_b, probs_a): | |||||
| >>> | |||||
| >>> # Similar calls can be made to other probability functions | |||||
| >>> # by replacing 'prob' with the name of the function | |||||
| >>> ans = self.b1('prob', value) | |||||
| >>> # Evaluate with the respect to distribution b | |||||
| >>> ans = self.b1('prob', value, probs_b) | |||||
| >>> | |||||
| >>> # probs must be passed in through construct | |||||
| >>> ans = self.b2('prob', value, probs_a) | |||||
| >>> | |||||
| >>> # Functions 'sd', 'var', 'entropy' have the same usage like 'mean' | |||||
| >>> # Will return [0.0] | |||||
| >>> ans = self.b1('mean') | |||||
| >>> # Will return mean_b | |||||
| >>> ans = self.b1('mean', probs_b) | |||||
| >>> | |||||
| >>> # probs must be passed in through construct | |||||
| >>> ans = self.b2('mean', probs_a) | |||||
| >>> | |||||
| >>> # Usage of 'kl_loss' and 'cross_entropy' are similar | |||||
| >>> ans = self.b1('kl_loss', 'Bernoulli', probs_b) | |||||
| >>> ans = self.b1('kl_loss', 'Bernoulli', probs_b, probs_a) | |||||
| >>> | |||||
| >>> # Additional probs_a must be passed in through construct | |||||
| >>> ans = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a) | |||||
| >>> | |||||
| >>> # Sample Usage | |||||
| >>> ans = self.b1('sample') | |||||
| >>> ans = self.b1('sample', (2,3)) | |||||
| >>> ans = self.b1('sample', (2,3), probs_b) | |||||
| >>> ans = self.b2('sample', (2,3), probs_a) | |||||
| """ | """ | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -50,29 +96,34 @@ class Bernoulli(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| super(Bernoulli, self).__init__(dtype, name, param) | super(Bernoulli, self).__init__(dtype, name, param) | ||||
| if probs is not None: | if probs is not None: | ||||
| self._probs = cast_to_tensor(probs) | |||||
| check_prob(self._probs) | |||||
| self._probs = cast_to_tensor(probs, dtype=mstype.float32) | |||||
| check_prob(self.probs) | |||||
| else: | else: | ||||
| self._probs = probs | self._probs = probs | ||||
| self.seed = seed | self.seed = seed | ||||
| # ops needed for the class | # ops needed for the class | ||||
| self.log = P.Log() | |||||
| self.add = P.TensorAdd() | |||||
| self.mul = P.Mul() | |||||
| self.sqrt = P.Sqrt() | |||||
| self.realdiv = P.RealDiv() | |||||
| self.shape = P.Shape() | |||||
| self.const = P.ScalarToArray() | |||||
| self.less = P.Less() | |||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.const = P.ScalarToArray() | |||||
| self.dtypeop = P.DType() | |||||
| self.erf = P.Erf() | self.erf = P.Erf() | ||||
| self.fill = P.Fill() | |||||
| self.log = P.Log() | |||||
| self.less = P.Less() | |||||
| self.shape = P.Shape() | |||||
| self.select = P.Select() | |||||
| self.sq = P.Square() | |||||
| self.sqrt = P.Sqrt() | self.sqrt = P.Sqrt() | ||||
| self.uniform = P.UniformReal(seed=seed) | |||||
| def extend_repr(self): | def extend_repr(self): | ||||
| str_info = f'probs = {self._probs}' | |||||
| if self.is_scalar_batch: | |||||
| str_info = f'probs = {self.probs}' | |||||
| else: | |||||
| str_info = f'batch_shape = {self._broadcast_shape}' | |||||
| return str_info | return str_info | ||||
| @property | |||||
| def probs(self): | def probs(self): | ||||
| """ | """ | ||||
| Returns the probability for the outcome is 1. | Returns the probability for the outcome is 1. | ||||
| @@ -85,7 +136,21 @@ class Bernoulli(Distribution): | |||||
| MEAN(B) = probs1 | MEAN(B) = probs1 | ||||
| """ | """ | ||||
| if name == 'mean': | if name == 'mean': | ||||
| return self._probs if probs1 is None else probs1 | |||||
| return self.probs if probs1 is None else probs1 | |||||
| return None | |||||
| def _mode(self, name='mode', probs1=None): | |||||
| r""" | |||||
| .. math:: | |||||
| MODE(B) = 1 if probs1 > 0.5 else = 0 | |||||
| """ | |||||
| if name == 'mode': | |||||
| probs1 = self.probs if probs1 is None else probs1 | |||||
| prob_type = self.dtypeop(probs1) | |||||
| zeros = self.fill(prob_type, self.shape(probs1), 0.0) | |||||
| ones = self.fill(prob_type, self.shape(probs1), 1.0) | |||||
| comp = self.less(0.5, probs1) | |||||
| return self.select(comp, ones, zeros) | |||||
| return None | return None | ||||
| def _var(self, name='var', probs1=None): | def _var(self, name='var', probs1=None): | ||||
| @@ -93,10 +158,35 @@ class Bernoulli(Distribution): | |||||
| .. math:: | .. math:: | ||||
| VAR(B) = probs1 * probs0 | VAR(B) = probs1 * probs0 | ||||
| """ | """ | ||||
| if name in ('sd', 'var'): | |||||
| probs1 = self._probs if probs1 is None else probs1 | |||||
| probs0 = self.add(1, -1 * probs1) | |||||
| return self.mul(probs0, probs1) | |||||
| if name in self._variance_functions: | |||||
| probs1 = self.probs if probs1 is None else probs1 | |||||
| probs0 = 1.0 - probs1 | |||||
| return probs0 * probs1 | |||||
| return None | |||||
| def _entropy(self, name='entropy', probs=None): | |||||
| r""" | |||||
| .. math:: | |||||
| H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) | |||||
| """ | |||||
| if name == 'entropy': | |||||
| probs1 = self.probs if probs is None else probs | |||||
| probs0 = 1 - probs1 | |||||
| return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) | |||||
| return None | |||||
| def _cross_entropy(self, name, dist, probs1_b, probs1_a=None): | |||||
| """ | |||||
| Evaluate cross_entropy between Bernoulli distributions. | |||||
| Args: | |||||
| name (str): name of the funtion. | |||||
| dist (str): type of the distributions. Should be "Bernoulli" in this case. | |||||
| probs1_b (Tensor): probs1 of distribution b. | |||||
| probs1_a (Tensor): probs1 of distribution a. Default: self.probs. | |||||
| """ | |||||
| if name == 'cross_entropy' and dist == 'Bernoulli': | |||||
| return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a) | |||||
| return None | return None | ||||
| def _prob(self, name, value, probs=None): | def _prob(self, name, value, probs=None): | ||||
| @@ -106,17 +196,43 @@ class Bernoulli(Distribution): | |||||
| Args: | Args: | ||||
| name (str): name of the function. Should be "prob" when passed in from construct. | name (str): name of the function. Should be "prob" when passed in from construct. | ||||
| value (Tensor): a Tensor composed of only zeros and ones. | value (Tensor): a Tensor composed of only zeros and ones. | ||||
| probs (Tensor): probability of outcome is 1. Default: self._probs. | |||||
| probs (Tensor): probability of outcome is 1. Default: self.probs. | |||||
| .. math:: | .. math:: | ||||
| pmf(k) = probs1 if k = 1; | pmf(k) = probs1 if k = 1; | ||||
| pmf(k) = probs0 if k = 0; | pmf(k) = probs0 if k = 0; | ||||
| """ | """ | ||||
| if name in ('prob', 'log_prob'): | |||||
| probs1 = self._probs if probs is None else probs | |||||
| probs0 = self.add(1, -1 * probs1) | |||||
| return self.add(self.mul(probs1, value), | |||||
| self.mul(probs0, self.add(1, -1 * value))) | |||||
| if name in self._prob_functions: | |||||
| probs1 = self.probs if probs is None else probs | |||||
| probs0 = 1.0 - probs1 | |||||
| return (probs1 * value) + (probs0 * (1.0 - value)) | |||||
| return None | |||||
| def _cdf(self, name, value, probs=None): | |||||
| r""" | |||||
| cdf of Bernoulli distribution. | |||||
| Args: | |||||
| name (str): name of the function. | |||||
| value (Tensor): value to be evaluated. | |||||
| probs (Tensor): probability of outcome is 1. Default: self.probs. | |||||
| .. math:: | |||||
| cdf(k) = 0 if k < 0; | |||||
| cdf(k) = probs0 if 0 <= k <1; | |||||
| cdf(k) = 1 if k >=1; | |||||
| """ | |||||
| if name in self._cdf_survival_functions: | |||||
| probs1 = self.probs if probs is None else probs | |||||
| prob_type = self.dtypeop(probs1) | |||||
| value = value * self.fill(prob_type, self.shape(probs1), 1.0) | |||||
| probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0) | |||||
| comp_zero = self.less(value, 0.0) | |||||
| comp_one = self.less(value, 1.0) | |||||
| zeros = self.fill(prob_type, self.shape(value), 0.0) | |||||
| ones = self.fill(prob_type, self.shape(value), 1.0) | |||||
| less_than_zero = self.select(comp_zero, zeros, probs0) | |||||
| return self.select(comp_one, less_than_zero, ones) | |||||
| return None | return None | ||||
| def _kl_loss(self, name, dist, probs1_b, probs1_a=None): | def _kl_loss(self, name, dist, probs1_b, probs1_a=None): | ||||
| @@ -124,21 +240,20 @@ class Bernoulli(Distribution): | |||||
| Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). | Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). | ||||
| Args: | Args: | ||||
| name (str): name of the funtion. Should always be "kl_loss" when passed in from construct. | |||||
| name (str): name of the funtion. | |||||
| dist (str): type of the distributions. Should be "Bernoulli" in this case. | dist (str): type of the distributions. Should be "Bernoulli" in this case. | ||||
| probs1_b (Tensor): probs1 of distribution b. | probs1_b (Tensor): probs1 of distribution b. | ||||
| probs1_a (Tensor): probs1 of distribution a. Default: self._probs. | |||||
| probs1_a (Tensor): probs1 of distribution a. Default: self.probs. | |||||
| .. math:: | .. math:: | ||||
| KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) + | KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) + | ||||
| probs0_a * \log(\fract{probs0_a}{probs0_b}) | probs0_a * \log(\fract{probs0_a}{probs0_b}) | ||||
| """ | """ | ||||
| if name == 'kl_loss' and dist == 'Bernoulli': | |||||
| probs1_a = self._probs if probs1_a is None else probs1_a | |||||
| probs0_a = self.add(1, -1 * probs1_a) | |||||
| probs0_b = self.add(1, -1 * probs1_b) | |||||
| return self.add(probs1_a * self.log(self.realdiv(probs1_a, probs1_b)), | |||||
| probs0_a * self.log(self.realdiv(probs0_a, probs0_b))) | |||||
| if name in self._divergence_functions and dist == 'Bernoulli': | |||||
| probs1_a = self.probs if probs1_a is None else probs1_a | |||||
| probs0_a = 1.0 - probs1_a | |||||
| probs0_b = 1.0 - probs1_b | |||||
| return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b) | |||||
| return None | return None | ||||
| def _sample(self, name, shape=(), probs=None): | def _sample(self, name, shape=(), probs=None): | ||||
| @@ -148,21 +263,17 @@ class Bernoulli(Distribution): | |||||
| Args: | Args: | ||||
| name (str): name of the function. Should always be 'sample' when passed in from construct. | name (str): name of the function. Should always be 'sample' when passed in from construct. | ||||
| shape (tuple): shape of the sample. Default: (). | shape (tuple): shape of the sample. Default: (). | ||||
| probs (Tensor): probs1 of the samples. Default: self._probs. | |||||
| probs (Tensor): probs1 of the samples. Default: self.probs. | |||||
| Returns: | Returns: | ||||
| Tensor, shape is shape + batch_shape. | Tensor, shape is shape + batch_shape. | ||||
| """ | """ | ||||
| if name == 'sample': | if name == 'sample': | ||||
| probs1 = self._probs if probs is None else probs | |||||
| batch_shape = self.shape(probs1) | |||||
| sample_shape = shape + batch_shape | |||||
| mean_zero = self.const(0.0) | |||||
| sd_one = self.const(1.0) | |||||
| sqrt_two = self.sqrt(self.const(2.0)) | |||||
| sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) | |||||
| sample_uniform = 0.5 * (1 + self.erf(self.realdiv(sample_norm, sqrt_two))) | |||||
| probs1 = self.probs if probs is None else probs | |||||
| l_zero = self.const(0.0) | |||||
| h_one = self.const(1.0) | |||||
| sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one) | |||||
| sample = self.less(sample_uniform, probs1) | sample = self.less(sample_uniform, probs1) | ||||
| sample = self.cast(sample, self._dtype) | |||||
| sample = self.cast(sample, self.dtype) | |||||
| return sample | return sample | ||||
| return None | return None | ||||
| @@ -14,8 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """basic""" | """basic""" | ||||
| from ..cell import Cell | from ..cell import Cell | ||||
| from ._utils.utils import calc_broadcast_shape_from_param | |||||
| from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param | |||||
| class Distribution(Cell): | class Distribution(Cell): | ||||
| """ | """ | ||||
| @@ -29,19 +28,18 @@ class Distribution(Cell): | |||||
| Note: | Note: | ||||
| Derived class should override operations such as ,_mean, _prob, | Derived class should override operations such as ,_mean, _prob, | ||||
| and _log_prob. Functions should be called through construct when | and _log_prob. Functions should be called through construct when | ||||
| used inside a network in the form of function name followed by | |||||
| arguments. | |||||
| Examples: | |||||
| >>> class MyNormalDistribution(Distribution): | |||||
| >>> def __init__(self): | |||||
| >>> super(MyDistribution, self).__init__() | |||||
| >>> self._mean_value = Tensor([2.0,3.0]) | |||||
| >>> self._sd_value = Tensor([2.0,3.0]) | |||||
| >>> | |||||
| >>> def _mean(self): | |||||
| >>> return self._mean_value | |||||
| used inside a network. Arguments should be passed in through *args | |||||
| in the form of function name followed by additional arguments. | |||||
| Functions such as cdf and prob, require a value to be passed in while | |||||
| functions such as mean and sd do not require arguments other than name. | |||||
| Dist_spec_args are unique for each type of distribution. For example, mean and sd | |||||
| are the dist_spec_args for a Normal distribution. | |||||
| For all functions, passing in dist_spec_args, are optional. | |||||
| Passing in the additional dist_spec_args will make the result to be evaluated with | |||||
| new distribution specified by the dist_spec_args. But it won't change the | |||||
| original distribuion. | |||||
| """ | """ | ||||
| def __init__(self, | def __init__(self, | ||||
| dtype, | dtype, | ||||
| @@ -61,12 +59,40 @@ class Distribution(Cell): | |||||
| self._parameters[k] = param[k] | self._parameters[k] = param[k] | ||||
| # some attributes | # some attributes | ||||
| self._broadcast_shape = calc_broadcast_shape_from_param( | self._broadcast_shape = calc_broadcast_shape_from_param( | ||||
| self._parameters) | |||||
| self.parameters) | |||||
| self._is_scalar_batch = check_scalar_from_param(self.parameters) | |||||
| # set the function to call according to the derived class's attributes | # set the function to call according to the derived class's attributes | ||||
| self._set_prob() | self._set_prob() | ||||
| self._set_log_prob() | self._set_log_prob() | ||||
| self._set_sd() | self._set_sd() | ||||
| self._set_var() | |||||
| self._set_cdf() | |||||
| self._set_survival() | |||||
| self._set_log_cdf() | |||||
| self._set_log_survival() | |||||
| self._set_cross_entropy() | |||||
| self._prob_functions = ('prob', 'log_prob') | |||||
| self._cdf_survival_functions = ('cdf', 'log_cdf', 'survival_function', 'log_survival') | |||||
| self._variance_functions = ('var', 'sd') | |||||
| self._divergence_functions = ('kl_loss', 'cross_entropy') | |||||
| @property | |||||
| def name(self): | |||||
| return self._name | |||||
| @property | |||||
| def dtype(self): | |||||
| return self._dtype | |||||
| @property | |||||
| def parameters(self): | |||||
| return self._parameters | |||||
| @property | |||||
| def is_scalar_batch(self): | |||||
| return self._is_scalar_batch | |||||
| def _set_prob(self): | def _set_prob(self): | ||||
| """ | """ | ||||
| @@ -74,8 +100,8 @@ class Distribution(Cell): | |||||
| """ | """ | ||||
| if hasattr(self, '_prob'): | if hasattr(self, '_prob'): | ||||
| self._call_prob = self._prob | self._call_prob = self._prob | ||||
| elif hasattr(self, '_log_likelihood'): | |||||
| self._call_prob = self._calc_prob_from_log_likelihood | |||||
| elif hasattr(self, '_log_prob'): | |||||
| self._call_prob = self._calc_prob_from_log_prob | |||||
| def _set_sd(self): | def _set_sd(self): | ||||
| """ | """ | ||||
| @@ -86,45 +112,100 @@ class Distribution(Cell): | |||||
| elif hasattr(self, '_var'): | elif hasattr(self, '_var'): | ||||
| self._call_sd = self._calc_sd_from_var | self._call_sd = self._calc_sd_from_var | ||||
| def _set_var(self): | |||||
| """ | |||||
| Set variance based on the availability of _sd and _var. | |||||
| """ | |||||
| if hasattr(self, '_var'): | |||||
| self._call_var = self._var | |||||
| elif hasattr(self, '_sd'): | |||||
| self._call_var = self._calc_var_from_sd | |||||
| def _set_log_prob(self): | def _set_log_prob(self): | ||||
| """ | """ | ||||
| Set log probability based on the availability of _prob and _log_likelihood. | |||||
| Set log probability based on the availability of _prob and _log_prob. | |||||
| """ | """ | ||||
| if hasattr(self, '_log_likelihood'): | |||||
| self._call_log_prob = self._log_likelihood | |||||
| if hasattr(self, '_prob'): | |||||
| if hasattr(self, '_log_prob'): | |||||
| self._call_log_prob = self._log_prob | |||||
| elif hasattr(self, '_prob'): | |||||
| self._call_log_prob = self._calc_log_prob_from_prob | self._call_log_prob = self._calc_log_prob_from_prob | ||||
| def log_likelihood(self, *args): | |||||
| def _set_cdf(self): | |||||
| """ | |||||
| Set cdf based on the availability of _cdf and _log_cdf and survival_functions. | |||||
| """ | |||||
| if hasattr(self, '_cdf'): | |||||
| self._call_cdf = self._cdf | |||||
| elif hasattr(self, '_log_cdf'): | |||||
| self._call_cdf = self._calc_cdf_from_log_cdf | |||||
| elif hasattr(self, '_survival_function'): | |||||
| self._call_cdf = self._calc_cdf_from_survival | |||||
| elif hasattr(self, '_log_survival'): | |||||
| self._call_cdf = self._calc_cdf_from_log_survival | |||||
| def _set_survival(self): | |||||
| """ | """ | ||||
| Evaluate the log probability at the given value. | |||||
| Set survival function based on the availability of _survival function and _log_survival | |||||
| and _call_cdf. | |||||
| """ | |||||
| if hasattr(self, '_survival_function'): | |||||
| self._call_survival = self._survival_function | |||||
| elif hasattr(self, '_log_survival'): | |||||
| self._call_survival = self._calc_survival_from_log_survival | |||||
| elif hasattr(self, '_call_cdf'): | |||||
| self._call_survival = self._calc_survival_from_call_cdf | |||||
| def _set_log_cdf(self): | |||||
| """ | |||||
| Set log cdf based on the availability of _log_cdf and _call_cdf. | |||||
| """ | |||||
| if hasattr(self, '_log_cdf'): | |||||
| self._call_log_cdf = self._log_cdf | |||||
| elif hasattr(self, '_call_cdf'): | |||||
| self._call_log_cdf = self._calc_log_cdf_from_call_cdf | |||||
| Note: | |||||
| value is casted to Tensor for further calculation. | |||||
| def _set_log_survival(self): | |||||
| """ | |||||
| Set log survival based on the availability of _log_survival and _call_survival. | |||||
| """ | |||||
| if hasattr(self, '_log_survival'): | |||||
| self._call_log_survival = self._log_survival | |||||
| elif hasattr(self, '_call_survival'): | |||||
| self._call_log_survival = self._calc_log_survival_from_call_survival | |||||
| Returns: | |||||
| Tensor, shape is the broadcast_shape of the distribution. | |||||
| def _set_cross_entropy(self): | |||||
| """ | |||||
| Set log survival based on the availability of _cross_entropy. | |||||
| """ | |||||
| if hasattr(self, '_cross_entropy'): | |||||
| self._call_cross_entropy = self._cross_entropy | |||||
| def log_prob(self, *args): | |||||
| """ | |||||
| Evaluate the log probability(pdf or pmf) at the given value. | |||||
| Note: | |||||
| Args must include name of the function and value. | |||||
| Dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._call_log_prob(*args) | return self._call_log_prob(*args) | ||||
| def _calc_prob_from_log_likelihood(self, *args): | |||||
| def _calc_prob_from_log_prob(self, *args): | |||||
| r""" | r""" | ||||
| Evaluate prob from log probability. | Evaluate prob from log probability. | ||||
| .. math:: | .. math:: | ||||
| probability(x) = \exp(log_likehood(x)) | probability(x) = \exp(log_likehood(x)) | ||||
| """ | """ | ||||
| return self.exp(self._log_likelihood(*args)) | |||||
| return self.exp(self._log_prob(*args)) | |||||
| def prob(self, *args): | def prob(self, *args): | ||||
| """ | """ | ||||
| Evaluate the prob (pdf or pmf) at given value. | |||||
| Evaluate the probability (pdf or pmf) at given value. | |||||
| Note: | Note: | ||||
| value is casted to Tensor for further calculation. | |||||
| Returns: | |||||
| Tensor, shape is the broadcast_shape of the distribution. | |||||
| Args must include name of the function and value. | |||||
| Dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._call_prob(*args) | return self._call_prob(*args) | ||||
| @@ -137,33 +218,154 @@ class Distribution(Cell): | |||||
| """ | """ | ||||
| return self.log(self._prob(*args)) | return self.log(self._prob(*args)) | ||||
| def kl_loss(self, **kwargs): | |||||
| def cdf(self, *args): | |||||
| """ | """ | ||||
| Evaluate the KL divergence. Parameters of the second distribution should be | |||||
| passed in through **kwargs. | |||||
| Evaluate the cdf at given value. | |||||
| Returns: | |||||
| Tensor, shape is the broadcast_shape of the distribution and input distribution. | |||||
| Note: | |||||
| Args must include name of the function and value. | |||||
| Dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._kl_loss(**kwargs) | |||||
| return self._call_cdf(*args) | |||||
| def _calc_cdf_from_log_cdf(self, *args): | |||||
| r""" | |||||
| Evaluate cdf from log_cdf. | |||||
| def mean(self, **kwargs): | |||||
| .. math:: | |||||
| cdf(x) = \exp(log_cdf(x)) | |||||
| """ | |||||
| return self.exp(self._log_cdf(*args)) | |||||
| def _calc_cdf_from_survival(self, *args): | |||||
| r""" | |||||
| Evaluate cdf from survival function. | |||||
| .. math:: | |||||
| cdf(x) = 1 - (survival_function(x)) | |||||
| """ | |||||
| return 1.0 - self._survival_function(*args) | |||||
| def _calc_cdf_from_log_survival(self, *args): | |||||
| r""" | |||||
| Evaluate cdf from log survival function. | |||||
| .. math:: | |||||
| cdf(x) = 1 - (\exp(log_survival(x))) | |||||
| """ | |||||
| return 1.0 - self.exp(self._log_survival(*args)) | |||||
| def log_cdf(self, *args): | |||||
| """ | |||||
| Evaluate the log cdf at given value. | |||||
| Note: | |||||
| Args must include name of the function and value. | |||||
| Dist_spec_args are optional. | |||||
| """ | |||||
| return self._call_log_cdf(*args) | |||||
| def _calc_log_cdf_from_call_cdf(self, *args): | |||||
| r""" | |||||
| Evaluate log cdf from cdf. | |||||
| .. math:: | |||||
| log_cdf(x) = \log(cdf(x)) | |||||
| """ | |||||
| return self.log(self._call_cdf(*args)) | |||||
| def survival_function(self, *args): | |||||
| """ | |||||
| Evaluate the survival function at given value. | |||||
| Note: | |||||
| Args must include name of the function and value. | |||||
| Dist_spec_args are optional. | |||||
| """ | |||||
| return self._call_survival(*args) | |||||
| def _calc_survival_from_call_cdf(self, *args): | |||||
| r""" | |||||
| Evaluate survival function from cdf. | |||||
| .. math:: | |||||
| survival_function(x) = 1 - (cdf(x)) | |||||
| """ | |||||
| return 1.0 - self._call_cdf(*args) | |||||
| def _calc_survival_from_log_survival(self, *args): | |||||
| r""" | |||||
| Evaluate survival function from log survival function. | |||||
| .. math:: | |||||
| survival(x) = \exp(survival_function(x)) | |||||
| """ | |||||
| return self.exp(self._log_survival(*args)) | |||||
| def log_survival(self, *args): | |||||
| """ | |||||
| Evaluate the log survival function at given value. | |||||
| Note: | |||||
| Args must include name of the function and value. | |||||
| Dist_spec_args are optional. | |||||
| """ | |||||
| return self._call_log_survival(*args) | |||||
| def _calc_log_survival_from_call_survival(self, *args): | |||||
| r""" | |||||
| Evaluate log survival function from survival function. | |||||
| .. math:: | |||||
| log_survival(x) = \log(survival_function(x)) | |||||
| """ | |||||
| return self.log(self._call_survival(*args)) | |||||
| def kl_loss(self, *args): | |||||
| """ | |||||
| Evaluate the KL divergence, i.e. KL(a||b). | |||||
| Note: | |||||
| Args must include name of the function, type of the distribution, parameters of distribution b. | |||||
| Parameters for distribution a are optional. | |||||
| """ | |||||
| return self._kl_loss(*args) | |||||
| def mean(self, *args): | |||||
| """ | """ | ||||
| Evaluate the mean. | Evaluate the mean. | ||||
| Returns: | |||||
| Tensor, shape is the broadcast_shape of the distribution. | |||||
| Note: | |||||
| Args must include the name of function. Dist_spec_args are optional. | |||||
| """ | |||||
| return self._mean(*args) | |||||
| def mode(self, *args): | |||||
| """ | """ | ||||
| return self._mean(**kwargs) | |||||
| Evaluate the mode. | |||||
| def sd(self, **kwargs): | |||||
| Note: | |||||
| Args must include the name of function. Dist_spec_args are optional. | |||||
| """ | |||||
| return self._mode(*args) | |||||
| def sd(self, *args): | |||||
| """ | """ | ||||
| Evaluate the standard deviation. | Evaluate the standard deviation. | ||||
| Returns: | |||||
| Tensor, shape is the broadcast_shape of the distribution. | |||||
| Note: | |||||
| Args must include the name of function. Dist_spec_args are optional. | |||||
| """ | """ | ||||
| return self._call_sd(**kwargs) | |||||
| return self._call_sd(*args) | |||||
| def var(self, *args): | |||||
| """ | |||||
| Evaluate the variance. | |||||
| Note: | |||||
| Args must include the name of function. Dist_spec_args are optional. | |||||
| """ | |||||
| return self._call_var(*args) | |||||
| def _calc_sd_from_var(self, *args): | def _calc_sd_from_var(self, *args): | ||||
| r""" | r""" | ||||
| @@ -174,27 +376,96 @@ class Distribution(Cell): | |||||
| """ | """ | ||||
| return self.sqrt(self._var(*args)) | return self.sqrt(self._var(*args)) | ||||
| def _calc_var_from_sd(self, *args): | |||||
| r""" | |||||
| Evaluate log probability from probability. | |||||
| .. math:: | |||||
| VAR(x) = STD(x) ^ 2 | |||||
| """ | |||||
| return self.sq(self._sd(*args)) | |||||
| def entropy(self, *args): | |||||
| """ | |||||
| Evaluate the entropy. | |||||
| Note: | |||||
| Args must include the name of function. Dist_spec_args are optional. | |||||
| """ | |||||
| return self._entropy(*args) | |||||
| def cross_entropy(self, *args): | |||||
| """ | |||||
| Evaluate the cross_entropy between distribution a and b. | |||||
| Note: | |||||
| Args must include name of the function, type of the distribution, parameters of distribution b. | |||||
| Parameters for distribution a are optional. | |||||
| """ | |||||
| return self._call_cross_entropy(*args) | |||||
| def _calc_cross_entropy(self, *args): | |||||
| r""" | |||||
| Evaluate cross_entropy from entropy and kl divergence. | |||||
| .. math:: | |||||
| H(X, Y) = H(X) + KL(X||Y) | |||||
| """ | |||||
| return self._entropy(*args) + self._kl_loss(*args) | |||||
| def sample(self, *args): | |||||
| """ | |||||
| Sampling function. | |||||
| Args: | |||||
| *args (list): arguments passed in through construct. | |||||
| Note: | |||||
| Args must include name of the function. | |||||
| Shape of the sample and dist_spec_args are optional. | |||||
| """ | |||||
| return self._sample(*args) | |||||
| def construct(self, *inputs): | def construct(self, *inputs): | ||||
| """ | """ | ||||
| Override construct in Cell. | Override construct in Cell. | ||||
| Args: | |||||
| *inputs: inputs[0] is always the name of the function. | |||||
| Note: | |||||
| Names of supported functions: | |||||
| 'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival' | |||||
| 'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'. | |||||
| Notes: | |||||
| Always raise RuntimeError as Distribution should not be called directly. | |||||
| Args: | |||||
| *inputs (list): inputs[0] is always the name of the function. | |||||
| """ | """ | ||||
| if inputs[0] == 'log_prob': | if inputs[0] == 'log_prob': | ||||
| return self._call_log_prob(*inputs) | return self._call_log_prob(*inputs) | ||||
| if inputs[0] == 'prob': | if inputs[0] == 'prob': | ||||
| return self._call_prob(*inputs) | return self._call_prob(*inputs) | ||||
| if inputs[0] == 'cdf': | |||||
| return self._call_cdf(*inputs) | |||||
| if inputs[0] == 'log_cdf': | |||||
| return self._call_log_cdf(*inputs) | |||||
| if inputs[0] == 'survival_function': | |||||
| return self._call_survival(*inputs) | |||||
| if inputs[0] == 'log_survival': | |||||
| return self._call_log_survival(*inputs) | |||||
| if inputs[0] == 'kl_loss': | if inputs[0] == 'kl_loss': | ||||
| return self._kl_loss(*inputs) | return self._kl_loss(*inputs) | ||||
| if inputs[0] == 'mean': | if inputs[0] == 'mean': | ||||
| return self._mean(*inputs) | return self._mean(*inputs) | ||||
| if inputs[0] == 'mode': | |||||
| return self._mode(*inputs) | |||||
| if inputs[0] == 'sd': | if inputs[0] == 'sd': | ||||
| return self._call_sd(*inputs) | return self._call_sd(*inputs) | ||||
| if inputs[0] == 'var': | |||||
| return self._call_var(*inputs) | |||||
| if inputs[0] == 'entropy': | |||||
| return self._entropy(*inputs) | |||||
| if inputs[0] == 'cross_entropy': | |||||
| return self._call_cross_entropy(*inputs) | |||||
| if inputs[0] == 'sample': | if inputs[0] == 'sample': | ||||
| return self._sample(*inputs) | return self._sample(*inputs) | ||||
| return None | return None | ||||
| @@ -0,0 +1,268 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Exponential Distribution""" | |||||
| import numpy as np | |||||
| from mindspore.ops import operations as P | |||||
| from .distribution import Distribution | |||||
| from ...common import dtype as mstype | |||||
| from ._utils.utils import cast_to_tensor, check_greater_zero | |||||
| class Exponential(Distribution): | |||||
| """ | |||||
| Example class: Exponential Distribution. | |||||
| Args: | |||||
| rate (float, list, numpy.ndarray, Tensor, Parameter): inverse scale. | |||||
| seed (int): seed to use in sampling. Default: 0. | |||||
| dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. | |||||
| name (str): name of the distribution. Default: Exponential. | |||||
| Note: | |||||
| rate should be strictly greater than 0. | |||||
| Dist_spec_args is rate. | |||||
| Examples: | |||||
| >>> # To initialize an Exponential distribution of rate 0.5 | |||||
| >>> n = nn.Exponential(0.5, dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # The following creates two independent Exponential distributions | |||||
| >>> n = nn.Exponential([0.5, 0.5], dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # A Exponential distribution can be initilized without arguments | |||||
| >>> # In this case, rate must be passed in through construct. | |||||
| >>> n = nn.Exponential(dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # To use Exponential distribution in a network | |||||
| >>> class net(Cell): | |||||
| >>> def __init__(self): | |||||
| >>> super(net, self).__init__(): | |||||
| >>> self.e1 = nn.Exponential(0.5, dtype=mstype.float32) | |||||
| >>> self.e2 = nn.Exponential(dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # All the following calls in construct are valid | |||||
| >>> def construct(self, value, rate_b, rate_a): | |||||
| >>> | |||||
| >>> # Similar calls can be made to other probability functions | |||||
| >>> # by replacing 'prob' with the name of the function | |||||
| >>> ans = self.e1('prob', value) | |||||
| >>> # Evaluate with the respect to distribution b | |||||
| >>> ans = self.e1('prob', value, rate_b) | |||||
| >>> | |||||
| >>> # Rate must be passed in through construct | |||||
| >>> ans = self.e2('prob', value, rate_a) | |||||
| >>> | |||||
| >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' | |||||
| >>> # Will return [0.0] | |||||
| >>> ans = self.e1('mean') | |||||
| >>> # Will return mean_b | |||||
| >>> ans = self.e1('mean', rate_b) | |||||
| >>> | |||||
| >>> # Rate must be passed in through construct | |||||
| >>> ans = self.e2('mean', rate_a) | |||||
| >>> | |||||
| >>> # Usage of 'kl_loss' and 'cross_entropy' are similar | |||||
| >>> ans = self.e1('kl_loss', 'Exponential', rate_b) | |||||
| >>> ans = self.e1('kl_loss', 'Exponential', rate_b, rate_a) | |||||
| >>> | |||||
| >>> # Additional rate must be passed in through construct | |||||
| >>> ans = self.e2('kl_loss', 'Exponential', rate_b, rate_a) | |||||
| >>> | |||||
| >>> # Sample Usage | |||||
| >>> ans = self.e1('sample') | |||||
| >>> ans = self.e1('sample', (2,3)) | |||||
| >>> ans = self.e1('sample', (2,3), rate_b) | |||||
| >>> ans = self.e2('sample', (2,3), rate_a) | |||||
| """ | |||||
| def __init__(self, | |||||
| rate=None, | |||||
| seed=0, | |||||
| dtype=mstype.float32, | |||||
| name="Exponential"): | |||||
| """ | |||||
| Constructor of Exponential distribution. | |||||
| """ | |||||
| param = dict(locals()) | |||||
| super(Exponential, self).__init__(dtype, name, param) | |||||
| if rate is not None: | |||||
| self._rate = cast_to_tensor(rate, mstype.float32) | |||||
| check_greater_zero(self._rate, "rate") | |||||
| else: | |||||
| self._rate = rate | |||||
| self.minval = np.finfo(np.float).tiny | |||||
| # ops needed for the class | |||||
| self.const = P.ScalarToArray() | |||||
| self.dtypeop = P.DType() | |||||
| self.exp = P.Exp() | |||||
| self.fill = P.Fill() | |||||
| self.less = P.Less() | |||||
| self.log = P.Log() | |||||
| self.select = P.Select() | |||||
| self.shape = P.Shape() | |||||
| self.sqrt = P.Sqrt() | |||||
| self.sq = P.Square() | |||||
| self.uniform = P.UniformReal(seed=seed) | |||||
| def extend_repr(self): | |||||
| if self.is_scalar_batch: | |||||
| str_info = f'rate = {self.rate}' | |||||
| else: | |||||
| str_info = f'batch_shape = {self._broadcast_shape}' | |||||
| return str_info | |||||
| @property | |||||
| def rate(self): | |||||
| """ | |||||
| Return rate of the distribution. | |||||
| """ | |||||
| return self._rate | |||||
| def _mean(self, name='mean', rate=None): | |||||
| r""" | |||||
| .. math:: | |||||
| MEAN(EXP) = \fract{1.0}{\lambda}. | |||||
| """ | |||||
| if name == 'mean': | |||||
| rate = self.rate if rate is None else rate | |||||
| return 1.0 / rate | |||||
| return None | |||||
| def _mode(self, name='mode', rate=None): | |||||
| r""" | |||||
| .. math:: | |||||
| MODE(EXP) = 0. | |||||
| """ | |||||
| if name == 'mode': | |||||
| rate = self.rate if rate is None else rate | |||||
| return self.fill(self.dtype, self.shape(rate), 0.) | |||||
| return None | |||||
| def _sd(self, name='sd', rate=None): | |||||
| r""" | |||||
| .. math:: | |||||
| sd(EXP) = \fract{1.0}{\lambda}. | |||||
| """ | |||||
| if name in self._variance_functions: | |||||
| rate = self.rate if rate is None else rate | |||||
| return 1.0 / rate | |||||
| return None | |||||
| def _entropy(self, name='entropy', rate=None): | |||||
| r""" | |||||
| .. math:: | |||||
| H(Exp) = 1 - \log(\lambda). | |||||
| """ | |||||
| rate = self.rate if rate is None else rate | |||||
| if name == 'entropy': | |||||
| return 1.0 - self.log(rate) | |||||
| return None | |||||
| def _cross_entropy(self, name, dist, rate_b, rate_a=None): | |||||
| """ | |||||
| Evaluate cross_entropy between Exponential distributions. | |||||
| Args: | |||||
| name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct. | |||||
| dist (str): type of the distributions. Should be "Exponential" in this case. | |||||
| rate_b (Tensor): rate of distribution b. | |||||
| rate_a (Tensor): rate of distribution a. Default: self.rate. | |||||
| """ | |||||
| if name == 'cross_entropy' and dist == 'Exponential': | |||||
| return self._entropy(rate=rate_a) + self._kl_loss(name, dist, rate_b, rate_a) | |||||
| return None | |||||
| def _prob(self, name, value, rate=None): | |||||
| r""" | |||||
| pdf of Exponential distribution. | |||||
| Args: | |||||
| Args: | |||||
| name (str): name of the function. | |||||
| value (Tensor): value to be evaluated. | |||||
| rate (Tensor): rate of the distribution. Default: self.rate. | |||||
| Note: | |||||
| Value should be greater or equal to zero. | |||||
| .. math:: | |||||
| pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 | |||||
| """ | |||||
| if name in self._prob_functions: | |||||
| rate = self.rate if rate is None else rate | |||||
| prob = rate * self.exp(-1. * rate * value) | |||||
| zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) | |||||
| comp = self.less(value, zeros) | |||||
| return self.select(comp, zeros, prob) | |||||
| return None | |||||
| def _cdf(self, name, value, rate=None): | |||||
| r""" | |||||
| cdf of Exponential distribution. | |||||
| Args: | |||||
| name (str): name of the function. | |||||
| value (Tensor): value to be evaluated. | |||||
| rate (Tensor): rate of the distribution. Default: self.rate. | |||||
| Note: | |||||
| Value should be greater or equal to zero. | |||||
| .. math:: | |||||
| cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0 | |||||
| """ | |||||
| if name in self._cdf_survival_functions: | |||||
| rate = self.rate if rate is None else rate | |||||
| cdf = 1.0 - self.exp(-1. * rate * value) | |||||
| zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) | |||||
| comp = self.less(value, zeros) | |||||
| return self.select(comp, zeros, cdf) | |||||
| return None | |||||
| def _kl_loss(self, name, dist, rate_b, rate_a=None): | |||||
| """ | |||||
| Evaluate exp-exp kl divergence, i.e. KL(a||b). | |||||
| Args: | |||||
| name (str): name of the funtion. | |||||
| dist (str): type of the distributions. Should be "Exponential" in this case. | |||||
| rate_b (Tensor): rate of distribution b. | |||||
| rate_a (Tensor): rate of distribution a. Default: self.rate. | |||||
| """ | |||||
| if name in self._divergence_functions and dist == 'Exponential': | |||||
| rate_a = self.rate if rate_a is None else rate_a | |||||
| return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 | |||||
| return None | |||||
| def _sample(self, name, shape=(), rate=None): | |||||
| """ | |||||
| Sampling. | |||||
| Args: | |||||
| name (str): name of the function. | |||||
| shape (tuple): shape of the sample. Default: (). | |||||
| rate (Tensor): rate of the distribution. Default: self.rate. | |||||
| Returns: | |||||
| Tensor, shape is shape + batch_shape. | |||||
| """ | |||||
| if name == 'sample': | |||||
| rate = self.rate if rate is None else rate | |||||
| minval = self.const(self.minval) | |||||
| maxval = self.const(1.0) | |||||
| sample = self.uniform(shape + self.shape(rate), minval, maxval) | |||||
| return -self.log(sample) / rate | |||||
| return None | |||||
| @@ -0,0 +1,288 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Geometric Distribution""" | |||||
| import numpy as np | |||||
| from mindspore.ops import operations as P | |||||
| from .distribution import Distribution | |||||
| from ._utils.utils import cast_to_tensor, check_prob | |||||
| from ...common import dtype as mstype | |||||
| class Geometric(Distribution): | |||||
| """ | |||||
| Geometric Distribution. | |||||
| It represents k+1 Bernoulli trials needed to get one success, k is the number of failures. | |||||
| Args: | |||||
| probs (float, list, numpy.ndarray, Tensor, Parameter): probability of success. | |||||
| seed (int): seed to use in sampling. Default: 0. | |||||
| dtype (mindspore.dtype): type of the distribution. Default: mstype.int32. | |||||
| name (str): name of the distribution. Default: Geometric. | |||||
| Note: | |||||
| probs should be proper probabilities (0 <= p <= 1). | |||||
| Dist_spec_args is probs. | |||||
| Examples: | |||||
| >>> # To initialize a Geometric distribution of prob 0.5 | |||||
| >>> n = nn.Geometric(0.5, dtype=mstype.int32) | |||||
| >>> | |||||
| >>> # The following creates two independent Geometric distributions | |||||
| >>> n = nn.Geometric([0.5, 0.5], dtype=mstype.int32) | |||||
| >>> | |||||
| >>> # A Geometric distribution can be initilized without arguments | |||||
| >>> # In this case, probs must be passed in through construct. | |||||
| >>> n = nn.Geometric(dtype=mstype.int32) | |||||
| >>> | |||||
| >>> # To use Geometric distribution in a network | |||||
| >>> class net(Cell): | |||||
| >>> def __init__(self): | |||||
| >>> super(net, self).__init__(): | |||||
| >>> self.g1 = nn.Geometric(0.5, dtype=mstype.int32) | |||||
| >>> self.g2 = nn.Geometric(dtype=mstype.int32) | |||||
| >>> | |||||
| >>> # Tthe following calls are valid in construct | |||||
| >>> def construct(self, value, probs_b, probs_a): | |||||
| >>> | |||||
| >>> # Similar calls can be made to other probability functions | |||||
| >>> # by replacing 'prob' with the name of the function | |||||
| >>> ans = self.g1('prob', value) | |||||
| >>> # Evaluate with the respect to distribution b | |||||
| >>> ans = self.g1('prob', value, probs_b) | |||||
| >>> | |||||
| >>> # Probs must be passed in through construct | |||||
| >>> ans = self.g2('prob', value, probs_a) | |||||
| >>> | |||||
| >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' | |||||
| >>> # Will return [0.0] | |||||
| >>> ans = self.g1('mean') | |||||
| >>> # Will return mean_b | |||||
| >>> ans = self.g1('mean', probs_b) | |||||
| >>> | |||||
| >>> # Probs must be passed in through construct | |||||
| >>> ans = self.g2('mean', probs_a) | |||||
| >>> | |||||
| >>> # Usage of 'kl_loss' and 'cross_entropy' are similar | |||||
| >>> ans = self.g1('kl_loss', 'Geometric', probs_b) | |||||
| >>> ans = self.g1('kl_loss', 'Geometric', probs_b, probs_a) | |||||
| >>> | |||||
| >>> # Additional probs must be passed in through construct | |||||
| >>> ans = self.g2('kl_loss', 'Geometric', probs_b, probs_a) | |||||
| >>> | |||||
| >>> # Sample Usage | |||||
| >>> ans = self.g1('sample') | |||||
| >>> ans = self.g1('sample', (2,3)) | |||||
| >>> ans = self.g1('sample', (2,3), probs_b) | |||||
| >>> ans = self.g2('sample', (2,3), probs_a) | |||||
| """ | |||||
| def __init__(self, | |||||
| probs=None, | |||||
| seed=0, | |||||
| dtype=mstype.int32, | |||||
| name="Geometric"): | |||||
| """ | |||||
| Constructor of Geometric distribution. | |||||
| """ | |||||
| param = dict(locals()) | |||||
| super(Geometric, self).__init__(dtype, name, param) | |||||
| if probs is not None: | |||||
| self._probs = cast_to_tensor(probs, dtype=mstype.float32) | |||||
| check_prob(self._probs) | |||||
| else: | |||||
| self._probs = probs | |||||
| self.minval = np.finfo(np.float).tiny | |||||
| # ops needed for the class | |||||
| self.const = P.ScalarToArray() | |||||
| self.dtypeop = P.DType() | |||||
| self.fill = P.Fill() | |||||
| self.floor = P.Floor() | |||||
| self.issubclass = P.IsSubClass() | |||||
| self.less = P.Less() | |||||
| self.log = P.Log() | |||||
| self.pow = P.Pow() | |||||
| self.select = P.Select() | |||||
| self.shape = P.Shape() | |||||
| self.sq = P.Square() | |||||
| self.sqrt = P.Sqrt() | |||||
| self.uniform = P.UniformReal(seed=seed) | |||||
| def extend_repr(self): | |||||
| if self.is_scalar_batch: | |||||
| str_info = f'probs = {self.probs}' | |||||
| else: | |||||
| str_info = f'batch_shape = {self._broadcast_shape}' | |||||
| return str_info | |||||
| @property | |||||
| def probs(self): | |||||
| """ | |||||
| Returns the probability for the outcome is 1. | |||||
| """ | |||||
| return self._probs | |||||
| def _mean(self, name='mean', probs1=None): | |||||
| r""" | |||||
| .. math:: | |||||
| MEAN(Geo) = \fratc{1 - probs1}{probs1} | |||||
| """ | |||||
| if name == 'mean': | |||||
| probs1 = self.probs if probs1 is None else probs1 | |||||
| return (1. - probs1) / probs1 | |||||
| return None | |||||
| def _mode(self, name='mode', probs1=None): | |||||
| r""" | |||||
| .. math:: | |||||
| MODE(Geo) = 0 | |||||
| """ | |||||
| if name == 'mode': | |||||
| probs1 = self.probs if probs1 is None else probs1 | |||||
| return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.) | |||||
| return None | |||||
| def _var(self, name='var', probs1=None): | |||||
| r""" | |||||
| .. math:: | |||||
| VAR(Geo) = \fract{1 - probs1}{probs1 ^ {2}} | |||||
| """ | |||||
| if name in self._variance_functions: | |||||
| probs1 = self.probs if probs1 is None else probs1 | |||||
| return (1.0 - probs1) / self.sq(probs1) | |||||
| return None | |||||
| def _entropy(self, name='entropy', probs=None): | |||||
| r""" | |||||
| .. math:: | |||||
| H(Geo) = \fract{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1} | |||||
| """ | |||||
| if name == 'entropy': | |||||
| probs1 = self.probs if probs is None else probs | |||||
| probs0 = 1.0 - probs1 | |||||
| return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1 | |||||
| return None | |||||
| def _cross_entropy(self, name, dist, probs1_b, probs1_a=None): | |||||
| r""" | |||||
| Evaluate cross_entropy between Geometric distributions. | |||||
| Args: | |||||
| name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct. | |||||
| dist (str): type of the distributions. Should be "Geometric" in this case. | |||||
| probs1_b (Tensor): probability of success of distribution b. | |||||
| probs1_a (Tensor): probability of success of distribution a. Default: self.probs. | |||||
| """ | |||||
| if name == 'cross_entropy' and dist == 'Geometric': | |||||
| return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a) | |||||
| return None | |||||
| def _prob(self, name, value, probs=None): | |||||
| r""" | |||||
| pmf of Geometric distribution. | |||||
| Args: | |||||
| name (str): name of the function. Should be "prob" when passed in from construct. | |||||
| value (Tensor): a Tensor composed of only natural numbers. | |||||
| probs (Tensor): probability of success. Default: self.probs. | |||||
| .. math:: | |||||
| pmf(k) = probs0 ^k * probs1 if k >= 0; | |||||
| pmf(k) = 0 if k < 0. | |||||
| """ | |||||
| if name in self._prob_functions: | |||||
| probs1 = self.probs if probs is None else probs | |||||
| dtype = self.dtypeop(value) | |||||
| if self.issubclass(dtype, mstype.int_): | |||||
| pass | |||||
| elif self.issubclass(dtype, mstype.float_): | |||||
| value = self.floor(value) | |||||
| else: | |||||
| return None | |||||
| pmf = self.pow((1.0 - probs1), value) * probs1 | |||||
| zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) | |||||
| comp = self.less(value, zeros) | |||||
| return self.select(comp, zeros, pmf) | |||||
| return None | |||||
| def _cdf(self, name, value, probs=None): | |||||
| r""" | |||||
| cdf of Geometric distribution. | |||||
| Args: | |||||
| name (str): name of the function. | |||||
| value (Tensor): a Tensor composed of only natural numbers. | |||||
| probs (Tensor): probability of success. Default: self.probs. | |||||
| .. math:: | |||||
| cdf(k) = 1 - probs0 ^ (k+1) if k >= 0; | |||||
| cdf(k) = 0 if k < 0. | |||||
| """ | |||||
| if name in self._cdf_survival_functions: | |||||
| probs1 = self.probs if probs is None else probs | |||||
| probs0 = 1.0 - probs1 | |||||
| dtype = self.dtypeop(value) | |||||
| if self.issubclass(dtype, mstype.int_): | |||||
| pass | |||||
| elif self.issubclass(dtype, mstype.float_): | |||||
| value = self.floor(value) | |||||
| else: | |||||
| return None | |||||
| cdf = 1.0 - self.pow(probs0, value + 1.0) | |||||
| zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0) | |||||
| comp = self.less(value, zeros) | |||||
| return self.select(comp, zeros, cdf) | |||||
| return None | |||||
| def _kl_loss(self, name, dist, probs1_b, probs1_a=None): | |||||
| r""" | |||||
| Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b). | |||||
| Args: | |||||
| name (str): name of the funtion. | |||||
| dist (str): type of the distributions. Should be "Geometric" in this case. | |||||
| probs1_b (Tensor): probability of success of distribution b. | |||||
| probs1_a (Tensor): probability of success of distribution a. Default: self.probs. | |||||
| .. math:: | |||||
| KL(a||b) = \log(\fract{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract{probs0_a}{probs0_b}) | |||||
| """ | |||||
| if name in self._divergence_functions and dist == 'Geometric': | |||||
| probs1_a = self.probs if probs1_a is None else probs1_a | |||||
| probs0_a = 1.0 - probs1_a | |||||
| probs0_b = 1.0 - probs1_b | |||||
| return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b) | |||||
| return None | |||||
| def _sample(self, name, shape=(), probs=None): | |||||
| """ | |||||
| Sampling. | |||||
| Args: | |||||
| name (str): name of the function. Should always be 'sample' when passed in from construct. | |||||
| shape (tuple): shape of the sample. Default: (). | |||||
| probs (Tensor): probability of success. Default: self.probs. | |||||
| Returns: | |||||
| Tensor, shape is shape + batch_shape. | |||||
| """ | |||||
| if name == 'sample': | |||||
| probs = self.probs if probs is None else probs | |||||
| minval = self.const(self.minval) | |||||
| maxval = self.const(1.0) | |||||
| sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval) | |||||
| return self.floor(self.log(sample_uniform) / self.log(1.0 - probs)) | |||||
| return None | |||||
| @@ -23,24 +23,70 @@ from ...context import get_context | |||||
| class Normal(Distribution): | class Normal(Distribution): | ||||
| """ | """ | ||||
| Example class: Normal distribution. | |||||
| Normal distribution. | |||||
| Args: | Args: | ||||
| mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Gaussian distribution. | |||||
| sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Gaussian distribution. | |||||
| mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Normal distribution. | |||||
| sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Normal distribution. | |||||
| seed (int): seed to use in sampling. Default: 0. | seed (int): seed to use in sampling. Default: 0. | ||||
| dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. | dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. | ||||
| name (str): name of the distribution. Default: Normal. | name (str): name of the distribution. Default: Normal. | ||||
| Note: | Note: | ||||
| Standard deviation should be greater than zero. | Standard deviation should be greater than zero. | ||||
| Dist_spec_args are mean and sd. | |||||
| Examples: | Examples: | ||||
| >>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0 | |||||
| >>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32) | |||||
| >>> # The following create two independent normal distributions | |||||
| >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) | |||||
| >>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0 | |||||
| >>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # The following creates two independent Normal distributions | |||||
| >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # A normal distribution can be initilize without arguments | |||||
| >>> # In this case, mean and sd must be passed in through construct. | |||||
| >>> n = nn.Normal(dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # To use normal in a network | |||||
| >>> class net(Cell): | |||||
| >>> def __init__(self): | |||||
| >>> super(net, self).__init__(): | |||||
| >>> self.n1 = nn.Normal(0.0, 1.0, dtype=mstype.float32) | |||||
| >>> self.n2 = nn.Normal(dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # The following calls are valid in construct | |||||
| >>> def construct(self, value, mean_b, sd_b, mean_a, sd_a): | |||||
| >>> | |||||
| >>> # Similar calls can be made to other probability functions | |||||
| >>> # by replacing 'prob' with the name of the function | |||||
| >>> ans = self.n1('prob', value) | |||||
| >>> # Evaluate with the respect to distribution b | |||||
| >>> ans = self.n1('prob', value, mean_b, sd_b) | |||||
| >>> | |||||
| >>> # mean and sd must be passed in through construct | |||||
| >>> ans = self.n2('prob', value, mean_a, sd_a) | |||||
| >>> | |||||
| >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' | |||||
| >>> # Will return [0.0] | |||||
| >>> ans = self.n1('mean') | |||||
| >>> # Will return mean_b | |||||
| >>> ans = self.n1('mean', mean_b, sd_b) | |||||
| >>> | |||||
| >>> # mean and sd must be passed in through construct | |||||
| >>> ans = self.n2('mean', mean_a, sd_a) | |||||
| >>> | |||||
| >>> # Usage of 'kl_loss' and 'cross_entropy' are similar | |||||
| >>> ans = self.n1('kl_loss', 'Normal', mean_b, sd_b) | |||||
| >>> ans = self.n1('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a) | |||||
| >>> | |||||
| >>> # Additional mean and sd must be passed in through construct | |||||
| >>> ans = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a) | |||||
| >>> | |||||
| >>> # Sample Usage | |||||
| >>> ans = self.n1('sample') | |||||
| >>> ans = self.n1('sample', (2,3)) | |||||
| >>> ans = self.n1('sample', (2,3), mean_b, sd_b) | |||||
| >>> ans = self.n2('sample', (2,3), mean_a, sd_a) | |||||
| """ | """ | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -64,27 +110,29 @@ class Normal(Distribution): | |||||
| self.seed = seed | self.seed = seed | ||||
| #ops needed for the class | #ops needed for the class | ||||
| self.const = P.ScalarToArray() | |||||
| self.erf = P.Erf() | |||||
| self.exp = P.Exp() | self.exp = P.Exp() | ||||
| self.add = P.TensorAdd() | |||||
| self.mul = P.Mul() | |||||
| self.sq = P.Square() | |||||
| self.log = P.Log() | |||||
| self.sqrt = P.Sqrt() | |||||
| self.realdiv = P.RealDiv() | |||||
| self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step | self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step | ||||
| self.fill = P.Fill() | |||||
| self.log = P.Log() | |||||
| self.shape = P.Shape() | self.shape = P.Shape() | ||||
| self.sq = P.Square() | |||||
| self.sqrt = P.Sqrt() | |||||
| self.zeroslike = P.ZerosLike() | self.zeroslike = P.ZerosLike() | ||||
| self.const = P.ScalarToArray() | |||||
| def extend_repr(self): | def extend_repr(self): | ||||
| str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' | |||||
| if self.is_scalar_batch: | |||||
| str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' | |||||
| else: | |||||
| str_info = f'batch_shape = {self._broadcast_shape}' | |||||
| return str_info | return str_info | ||||
| def _expm1_by_step(self, x): | def _expm1_by_step(self, x): | ||||
| """ | """ | ||||
| Expm1 ops under GPU context. | Expm1 ops under GPU context. | ||||
| """ | """ | ||||
| return self.add(self.exp(x), -1) | |||||
| return self.exp(x) - 1.0 | |||||
| def _mean(self, name='mean', mean=None, sd=None): | def _mean(self, name='mean', mean=None, sd=None): | ||||
| """ | """ | ||||
| @@ -95,29 +143,92 @@ class Normal(Distribution): | |||||
| return mean | return mean | ||||
| return None | return None | ||||
| def _mode(self, name='mode', mean=None, sd=None): | |||||
| """ | |||||
| Mode of the distribution. | |||||
| """ | |||||
| if name == 'mode': | |||||
| mean = self._mean_value if mean is None or sd is None else mean | |||||
| return mean | |||||
| return None | |||||
| def _sd(self, name='sd', mean=None, sd=None): | def _sd(self, name='sd', mean=None, sd=None): | ||||
| """ | """ | ||||
| Standard deviation of the distribution. | Standard deviation of the distribution. | ||||
| """ | """ | ||||
| if name in ('sd', 'var'): | |||||
| if name in self._variance_functions: | |||||
| sd = self._sd_value if mean is None or sd is None else sd | sd = self._sd_value if mean is None or sd is None else sd | ||||
| return sd | return sd | ||||
| return None | return None | ||||
| def _log_likelihood(self, name, value, mean=None, sd=None): | |||||
| def _entropy(self, name='entropy', sd=None): | |||||
| r""" | |||||
| Evaluate entropy. | |||||
| .. math:: | |||||
| H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) | |||||
| """ | |||||
| if name == 'entropy': | |||||
| sd = self._sd_value if sd is None else sd | |||||
| return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd))) | |||||
| return None | |||||
| def _cross_entropy(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None): | |||||
| r""" | |||||
| Evaluate cross_entropy between normal distributions. | |||||
| Args: | |||||
| name (str): name of the funtion passed in from construct. Should always be "cross_entropy". | |||||
| dist (str): type of the distributions. Should be "Normal" in this case. | |||||
| mean_b (Tensor): mean of distribution b. | |||||
| sd_b (Tensor): standard deviation distribution b. | |||||
| mean_a (Tensor): mean of distribution a. Default: self._mean_value. | |||||
| sd_a (Tensor): standard deviation distribution a. Default: self._sd_value. | |||||
| """ | |||||
| if name == 'cross_entropy' and dist == 'Normal': | |||||
| return self._entropy(sd=sd_a) + self._kl_loss(name, dist, mean_b, sd_b, mean_a, sd_a) | |||||
| return None | |||||
| def _log_prob(self, name, value, mean=None, sd=None): | |||||
| r""" | r""" | ||||
| Evaluate log probability. | Evaluate log probability. | ||||
| Args: | |||||
| name (str): name of the funtion passed in from construct. | |||||
| value (Tensor): value to be evaluated. | |||||
| mean (Tensor): mean of the distribution. Default: self._mean_value. | |||||
| sd (Tensor): standard deviation the distribution. Default: self._sd_value. | |||||
| .. math:: | .. math:: | ||||
| L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) | L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) | ||||
| """ | """ | ||||
| if name in ('prob', 'log_prob'): | |||||
| if name in self._prob_functions: | |||||
| mean = self._mean_value if mean is None else mean | mean = self._mean_value if mean is None else mean | ||||
| sd = self._sd_value if sd is None else sd | sd = self._sd_value if sd is None else sd | ||||
| unnormalized_log_prob = -1. * self.realdiv(self.sq(self.add(value, -1. * mean)), | |||||
| 2. * self.sq(sd)) | |||||
| unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) | |||||
| neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) | neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) | ||||
| return self.add(unnormalized_log_prob, neg_normalization) | |||||
| return unnormalized_log_prob + neg_normalization | |||||
| return None | |||||
| def _cdf(self, name, value, mean=None, sd=None): | |||||
| r""" | |||||
| Evaluate cdf of given value. | |||||
| Args: | |||||
| name (str): name of the funtion passed in from construct. Should always be "cdf". | |||||
| value (Tensor): value to be evaluated. | |||||
| mean (Tensor): mean of the distribution. Default: self._mean_value. | |||||
| sd (Tensor): standard deviation the distribution. Default: self._sd_value. | |||||
| .. math:: | |||||
| cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2)))) | |||||
| """ | |||||
| if name in self._cdf_survival_functions: | |||||
| mean = self._mean_value if mean is None else mean | |||||
| sd = self._sd_value if sd is None else sd | |||||
| sqrt2 = self.sqrt(self.const(2.0)) | |||||
| adjusted = (value - mean) / (sd * sqrt2) | |||||
| return 0.5 * (1.0 + self.erf(adjusted)) | |||||
| return None | return None | ||||
| def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None): | def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None): | ||||
| @@ -125,7 +236,7 @@ class Normal(Distribution): | |||||
| Evaluate Normal-Normal kl divergence, i.e. KL(a||b). | Evaluate Normal-Normal kl divergence, i.e. KL(a||b). | ||||
| Args: | Args: | ||||
| name (str): name of the funtion passed in from construct. Should always be "kl_loss". | |||||
| name (str): name of the funtion passed in from construct. | |||||
| dist (str): type of the distributions. Should be "Normal" in this case. | dist (str): type of the distributions. Should be "Normal" in this case. | ||||
| mean_b (Tensor): mean of distribution b. | mean_b (Tensor): mean of distribution b. | ||||
| sd_b (Tensor): standard deviation distribution b. | sd_b (Tensor): standard deviation distribution b. | ||||
| @@ -136,12 +247,12 @@ class Normal(Distribution): | |||||
| KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 + | KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 + | ||||
| 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) | 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) | ||||
| """ | """ | ||||
| if name == 'kl_loss' and dist == 'Normal': | |||||
| if name in self._divergence_functions and dist == 'Normal': | |||||
| mean_a = self._mean_value if mean_a is None else mean_a | mean_a = self._mean_value if mean_a is None else mean_a | ||||
| sd_a = self._sd_value if sd_a is None else sd_a | sd_a = self._sd_value if sd_a is None else sd_a | ||||
| diff_log_scale = self.add(self.log(sd_a), - self.log(sd_b)) | |||||
| squared_diff = self.sq(self.add(self.realdiv(mean_a, sd_b), - self.realdiv(mean_b, sd_b))) | |||||
| return self.add(self.add(0.5 * squared_diff, 0.5 * self.expm1(2 * diff_log_scale)), - diff_log_scale) | |||||
| diff_log_scale = self.log(sd_a) - self.log(sd_b) | |||||
| squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b) | |||||
| return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale | |||||
| return None | return None | ||||
| def _sample(self, name, shape=(), mean=None, sd=None): | def _sample(self, name, shape=(), mean=None, sd=None): | ||||
| @@ -160,11 +271,11 @@ class Normal(Distribution): | |||||
| if name == 'sample': | if name == 'sample': | ||||
| mean = self._mean_value if mean is None else mean | mean = self._mean_value if mean is None else mean | ||||
| sd = self._sd_value if sd is None else sd | sd = self._sd_value if sd is None else sd | ||||
| batch_shape = self.shape(self.add(self.zeroslike(mean), self.zeroslike(sd))) | |||||
| batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd)) | |||||
| sample_shape = shape + batch_shape | sample_shape = shape + batch_shape | ||||
| mean_zero = self.const(0.0) | mean_zero = self.const(0.0) | ||||
| sd_one = self.const(1.0) | sd_one = self.const(1.0) | ||||
| sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) | sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) | ||||
| sample = self.add(mean, self.mul(sample_norm, sd)) | |||||
| sample = mean + sample_norm * sd | |||||
| return sample | return sample | ||||
| return None | return None | ||||
| @@ -0,0 +1,304 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Uniform Distribution""" | |||||
| from mindspore.ops import operations as P | |||||
| from .distribution import Distribution | |||||
| from ...common import dtype as mstype | |||||
| from ._utils.utils import convert_to_batch, check_greater | |||||
| class Uniform(Distribution): | |||||
| """ | |||||
| Example class: Uniform Distribution. | |||||
| Args: | |||||
| low (int, float, list, numpy.ndarray, Tensor, Parameter): lower bound of the distribution. | |||||
| high (int, float, list, numpy.ndarray, Tensor, Parameter): upper bound of the distribution. | |||||
| seed (int): seed to use in sampling. Default: 0. | |||||
| dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. | |||||
| name (str): name of the distribution. Default: Uniform. | |||||
| Note: | |||||
| low should be stricly less than high. | |||||
| Dist_spec_args are high and low. | |||||
| Examples: | |||||
| >>> # To initialize a Uniform distribution of mean 3.0 and standard deviation 4.0 | |||||
| >>> n = nn.Uniform(0.0, 1.0, dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # The following creates two independent Uniform distributions | |||||
| >>> n = nn.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # A Uniform distribution can be initilized without arguments | |||||
| >>> # In this case, high and low must be passed in through construct. | |||||
| >>> n = nn.Uniform(dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # To use Uniform in a network | |||||
| >>> class net(Cell): | |||||
| >>> def __init__(self) | |||||
| >>> super(net, self).__init__(): | |||||
| >>> self.u1 = nn.Uniform(0.0, 1.0, dtype=mstype.float32) | |||||
| >>> self.u2 = nn.Uniform(dtype=mstype.float32) | |||||
| >>> | |||||
| >>> # All the following calls in construct are valid | |||||
| >>> def construct(self, value, low_b, high_b, low_a, high_a): | |||||
| >>> | |||||
| >>> # Similar calls can be made to other probability functions | |||||
| >>> # by replacing 'prob' with the name of the function | |||||
| >>> ans = self.u1('prob', value) | |||||
| >>> # Evaluate with the respect to distribution b | |||||
| >>> ans = self.u1('prob', value, low_b, high_b) | |||||
| >>> | |||||
| >>> # High and low must be passed in through construct | |||||
| >>> ans = self.u2('prob', value, low_a, high_a) | |||||
| >>> | |||||
| >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' | |||||
| >>> # Will return [0.0] | |||||
| >>> ans = self.u1('mean') | |||||
| >>> # Will return low_b | |||||
| >>> ans = self.u1('mean', low_b, high_b) | |||||
| >>> | |||||
| >>> # High and low must be passed in through construct | |||||
| >>> ans = self.u2('mean', low_a, high_a) | |||||
| >>> | |||||
| >>> # Usage of 'kl_loss' and 'cross_entropy' are similar | |||||
| >>> ans = self.u1('kl_loss', 'Uniform', low_b, high_b) | |||||
| >>> ans = self.u1('kl_loss', 'Uniform', low_b, high_b, low_a, high_a) | |||||
| >>> | |||||
| >>> # Additional high and low must be passed in through construct | |||||
| >>> ans = self.u2('kl_loss', 'Uniform', low_b, high_b, low_a, high_a) | |||||
| >>> | |||||
| >>> # Sample Usage | |||||
| >>> ans = self.u1('sample') | |||||
| >>> ans = self.u1('sample', (2,3)) | |||||
| >>> ans = self.u1('sample', (2,3), low_b, high_b) | |||||
| >>> ans = self.u2('sample', (2,3), low_a, high_a) | |||||
| """ | |||||
| def __init__(self, | |||||
| low=None, | |||||
| high=None, | |||||
| seed=0, | |||||
| dtype=mstype.float32, | |||||
| name="Uniform"): | |||||
| """ | |||||
| Constructor of Uniform distribution. | |||||
| """ | |||||
| param = dict(locals()) | |||||
| super(Uniform, self).__init__(dtype, name, param) | |||||
| if low is not None and high is not None: | |||||
| self._low = convert_to_batch(low, self._broadcast_shape, dtype) | |||||
| self._high = convert_to_batch(high, self._broadcast_shape, dtype) | |||||
| check_greater(self.low, self.high, "low value", "high value") | |||||
| else: | |||||
| self._low = low | |||||
| self._high = high | |||||
| # ops needed for the class | |||||
| self.const = P.ScalarToArray() | |||||
| self.dtypeop = P.DType() | |||||
| self.exp = P.Exp() | |||||
| self.fill = P.Fill() | |||||
| self.less = P.Less() | |||||
| self.lessequal = P.LessEqual() | |||||
| self.log = P.Log() | |||||
| self.logicaland = P.LogicalAnd() | |||||
| self.select = P.Select() | |||||
| self.shape = P.Shape() | |||||
| self.sq = P.Square() | |||||
| self.sqrt = P.Sqrt() | |||||
| self.uniform = P.UniformReal(seed=seed) | |||||
| self.zeroslike = P.ZerosLike() | |||||
| def extend_repr(self): | |||||
| if self.is_scalar_batch: | |||||
| str_info = f'low = {self.low}, high = {self.high}' | |||||
| else: | |||||
| str_info = f'batch_shape = {self._broadcast_shape}' | |||||
| return str_info | |||||
| @property | |||||
| def low(self): | |||||
| """ | |||||
| Return lower bound of the distribution. | |||||
| """ | |||||
| return self._low | |||||
| @property | |||||
| def high(self): | |||||
| """ | |||||
| Return upper bound of the distribution. | |||||
| """ | |||||
| return self._high | |||||
| def _range(self, name='range', low=None, high=None): | |||||
| r""" | |||||
| Return the range of the distribution. | |||||
| .. math:: | |||||
| range(U) = high -low | |||||
| """ | |||||
| if name == 'range': | |||||
| low = self.low if low is None else low | |||||
| high = self.high if high is None else high | |||||
| return high - low | |||||
| return None | |||||
| def _mean(self, name='mean', low=None, high=None): | |||||
| r""" | |||||
| .. math:: | |||||
| MEAN(U) = \fract{low + high}{2}. | |||||
| """ | |||||
| if name == 'mean': | |||||
| low = self.low if low is None else low | |||||
| high = self.high if high is None else high | |||||
| return (low + high) / 2. | |||||
| return None | |||||
| def _var(self, name='var', low=None, high=None): | |||||
| r""" | |||||
| .. math:: | |||||
| VAR(U) = \fract{(high -low) ^ 2}{12}. | |||||
| """ | |||||
| if name in self._variance_functions: | |||||
| low = self.low if low is None else low | |||||
| high = self.high if high is None else high | |||||
| return self.sq(high - low) / 12.0 | |||||
| return None | |||||
| def _entropy(self, name='entropy', low=None, high=None): | |||||
| r""" | |||||
| .. math:: | |||||
| H(U) = \log(high - low). | |||||
| """ | |||||
| if name == 'entropy': | |||||
| low = self.low if low is None else low | |||||
| high = self.high if high is None else high | |||||
| return self.log(high - low) | |||||
| return None | |||||
| def _cross_entropy(self, name, dist, low_b, high_b, low_a=None, high_a=None): | |||||
| """ | |||||
| Evaluate cross_entropy between Uniform distributoins. | |||||
| Args: | |||||
| name (str): name of the funtion. | |||||
| dist (str): type of the distributions. Should be "Uniform" in this case. | |||||
| low_b (Tensor): lower bound of distribution b. | |||||
| high_b (Tensor): upper bound of distribution b. | |||||
| low_a (Tensor): lower bound of distribution a. Default: self.low. | |||||
| high_a (Tensor): upper bound of distribution a. Default: self.high. | |||||
| """ | |||||
| if name == 'cross_entropy' and dist == 'Uniform': | |||||
| return self._entropy(low=low_a, high=high_a) + self._kl_loss(name, dist, low_b, high_b, low_a, high_a) | |||||
| return None | |||||
| def _prob(self, name, value, low=None, high=None): | |||||
| r""" | |||||
| pdf of Uniform distribution. | |||||
| Args: | |||||
| name (str): name of the function. | |||||
| value (Tensor): value to be evaluated. | |||||
| low (Tensor): lower bound of the distribution. Default: self.low. | |||||
| high (Tensor): upper bound of the distribution. Default: self.high. | |||||
| .. math:: | |||||
| pdf(x) = 0 if x < low; | |||||
| pdf(x) = \fract{1.0}{high -low} if low <= x <= high; | |||||
| pdf(x) = 0 if x > high; | |||||
| """ | |||||
| if name in self._prob_functions: | |||||
| low = self.low if low is None else low | |||||
| high = self.high if high is None else high | |||||
| ones = self.fill(self.dtype, self.shape(value), 1.0) | |||||
| prob = ones / (high - low) | |||||
| broadcast_shape = self.shape(prob) | |||||
| zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) | |||||
| comp_lo = self.less(value, low) | |||||
| comp_hi = self.lessequal(value, high) | |||||
| less_than_low = self.select(comp_lo, zeros, prob) | |||||
| return self.select(comp_hi, less_than_low, zeros) | |||||
| return None | |||||
| def _kl_loss(self, name, dist, low_b, high_b, low_a=None, high_a=None): | |||||
| """ | |||||
| Evaluate uniform-uniform kl divergence, i.e. KL(a||b). | |||||
| Args: | |||||
| name (str): name of the funtion. | |||||
| dist (str): type of the distributions. Should be "Uniform" in this case. | |||||
| low_b (Tensor): lower bound of distribution b. | |||||
| high_b (Tensor): upper bound of distribution b. | |||||
| low_a (Tensor): lower bound of distribution a. Default: self.low. | |||||
| high_a (Tensor): upper bound of distribution a. Default: self.high. | |||||
| """ | |||||
| if name in self._divergence_functions and dist == 'Uniform': | |||||
| low_a = self.low if low_a is None else low_a | |||||
| high_a = self.high if high_a is None else high_a | |||||
| kl = self.log(high_b - low_b) / self.log(high_a - low_a) | |||||
| comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b)) | |||||
| return self.select(comp, kl, self.log(self.zeroslike(kl))) | |||||
| return None | |||||
| def _cdf(self, name, value, low=None, high=None): | |||||
| r""" | |||||
| cdf of Uniform distribution. | |||||
| Args: | |||||
| name (str): name of the function. | |||||
| value (Tensor): value to be evaluated. | |||||
| low (Tensor): lower bound of the distribution. Default: self.low. | |||||
| high (Tensor): upper bound of the distribution. Default: self.high. | |||||
| .. math:: | |||||
| cdf(x) = 0 if x < low; | |||||
| cdf(x) = \fract{x - low}{high -low} if low <= x <= high; | |||||
| cdf(x) = 1 if x > high; | |||||
| """ | |||||
| if name in self._cdf_survival_functions: | |||||
| low = self.low if low is None else low | |||||
| high = self.high if high is None else high | |||||
| prob = (value - low) / (high - low) | |||||
| broadcast_shape = self.shape(prob) | |||||
| zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) | |||||
| ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0) | |||||
| comp_lo = self.less(value, low) | |||||
| comp_hi = self.less(value, high) | |||||
| less_than_low = self.select(comp_lo, zeros, prob) | |||||
| return self.select(comp_hi, less_than_low, ones) | |||||
| return None | |||||
| def _sample(self, name, shape=(), low=None, high=None): | |||||
| """ | |||||
| Sampling. | |||||
| Args: | |||||
| name (str): name of the function. Should always be 'sample' when passed in from construct. | |||||
| shape (tuple): shape of the sample. Default: (). | |||||
| low (Tensor): lower bound of the distribution. Default: self.low. | |||||
| high (Tensor): upper bound of the distribution. Default: self.high. | |||||
| Returns: | |||||
| Tensor, shape is shape + batch_shape. | |||||
| """ | |||||
| if name == 'sample': | |||||
| low = self.low if low is None else low | |||||
| high = self.high if high is None else high | |||||
| broadcast_shape = self.shape(low + high) | |||||
| l_zero = self.const(0.0) | |||||
| h_one = self.const(1.0) | |||||
| sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one) | |||||
| sample = (high - low) * sample_uniform + low | |||||
| return sample | |||||
| return None | |||||
| @@ -23,60 +23,113 @@ from mindspore import dtype | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| class Net(nn.Cell): | |||||
| class Prob(nn.Cell): | |||||
| """ | """ | ||||
| Test class: probability of bernoulli distribution. | |||||
| Test class: probability of Bernoulli distribution. | |||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Net, self).__init__() | |||||
| super(Prob, self).__init__() | |||||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | ||||
| @ms_function | @ms_function | ||||
| def construct(self, x_): | def construct(self, x_): | ||||
| return self.b('prob', x_) | return self.b('prob', x_) | ||||
| class Net1(nn.Cell): | |||||
| def test_pmf(): | |||||
| """ | """ | ||||
| Test class: log probability of bernoulli distribution. | |||||
| Test pmf. | |||||
| """ | |||||
| bernoulli_benchmark = stats.bernoulli(0.7) | |||||
| expect_pmf = bernoulli_benchmark.pmf([0, 1, 0, 1, 1]).astype(np.float32) | |||||
| pmf = Prob() | |||||
| x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) | |||||
| output = pmf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_pmf) < tol).all() | |||||
| class LogProb(nn.Cell): | |||||
| """ | |||||
| Test class: log probability of Bernoulli distribution. | |||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Net1, self).__init__() | |||||
| super(LogProb, self).__init__() | |||||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | ||||
| @ms_function | @ms_function | ||||
| def construct(self, x_): | def construct(self, x_): | ||||
| return self.b('log_prob', x_) | return self.b('log_prob', x_) | ||||
| class Net2(nn.Cell): | |||||
| def test_log_likelihood(): | |||||
| """ | |||||
| Test log_pmf. | |||||
| """ | |||||
| bernoulli_benchmark = stats.bernoulli(0.7) | |||||
| expect_logpmf = bernoulli_benchmark.logpmf([0, 1, 0, 1, 1]).astype(np.float32) | |||||
| logprob = LogProb() | |||||
| x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) | |||||
| output = logprob(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all() | |||||
| class KL(nn.Cell): | |||||
| """ | """ | ||||
| Test class: kl_loss between bernoulli distributions. | |||||
| Test class: kl_loss between Bernoulli distributions. | |||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Net2, self).__init__() | |||||
| super(KL, self).__init__() | |||||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | ||||
| @ms_function | @ms_function | ||||
| def construct(self, x_): | def construct(self, x_): | ||||
| return self.b('kl_loss', 'Bernoulli', x_) | return self.b('kl_loss', 'Bernoulli', x_) | ||||
| class Net3(nn.Cell): | |||||
| def test_kl_loss(): | |||||
| """ | |||||
| Test kl_loss. | |||||
| """ | """ | ||||
| Test class: mean/sd of bernoulli distribution. | |||||
| probs1_a = 0.7 | |||||
| probs1_b = 0.5 | |||||
| probs0_a = 1 - probs1_a | |||||
| probs0_b = 1 - probs1_b | |||||
| expect_kl_loss = probs1_a * np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b) | |||||
| kl_loss = KL() | |||||
| output = kl_loss(Tensor([probs1_b], dtype=dtype.float32)) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() | |||||
| class Basics(nn.Cell): | |||||
| """ | |||||
| Test class: mean/sd/mode of Bernoulli distribution. | |||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Net3, self).__init__() | |||||
| self.b = nn.Bernoulli([0.5, 0.5], dtype=dtype.int32) | |||||
| super(Basics, self).__init__() | |||||
| self.b = nn.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32) | |||||
| @ms_function | @ms_function | ||||
| def construct(self): | def construct(self): | ||||
| return self.b('mean'), self.b('sd') | |||||
| return self.b('mean'), self.b('sd'), self.b('mode') | |||||
| class Net4(nn.Cell): | |||||
| def test_basics(): | |||||
| """ | |||||
| Test mean/standard deviation/mode. | |||||
| """ | """ | ||||
| Test class: log probability of bernoulli distribution. | |||||
| basics = Basics() | |||||
| mean, sd, mode = basics() | |||||
| expect_mean = [0.3, 0.5, 0.7] | |||||
| expect_sd = np.sqrt(np.multiply([0.7, 0.5, 0.3], [0.3, 0.5, 0.7])) | |||||
| expect_mode = [0.0, 0.0, 1.0] | |||||
| tol = 1e-6 | |||||
| assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() | |||||
| assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() | |||||
| assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() | |||||
| class Sampling(nn.Cell): | |||||
| """ | |||||
| Test class: log probability of Bernoulli distribution. | |||||
| """ | """ | ||||
| def __init__(self, shape, seed=0): | def __init__(self, shape, seed=0): | ||||
| super(Net4, self).__init__() | |||||
| super(Sampling, self).__init__() | |||||
| self.b = nn.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32) | self.b = nn.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32) | ||||
| self.shape = shape | self.shape = shape | ||||
| @@ -84,64 +137,159 @@ class Net4(nn.Cell): | |||||
| def construct(self, probs=None): | def construct(self, probs=None): | ||||
| return self.b('sample', self.shape, probs) | return self.b('sample', self.shape, probs) | ||||
| def test_pmf(): | |||||
| def test_sample(): | |||||
| """ | """ | ||||
| Test pmf. | |||||
| Test sample. | |||||
| """ | |||||
| shape = (2, 3) | |||||
| sample = Sampling(shape) | |||||
| output = sample() | |||||
| assert output.shape == (2, 3, 2) | |||||
| class CDF(nn.Cell): | |||||
| """ | |||||
| Test class: cdf of bernoulli distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(CDF, self).__init__() | |||||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.b('cdf', x_) | |||||
| def test_cdf(): | |||||
| """ | |||||
| Test cdf. | |||||
| """ | """ | ||||
| bernoulli_benchmark = stats.bernoulli(0.7) | bernoulli_benchmark = stats.bernoulli(0.7) | ||||
| expect_pmf = bernoulli_benchmark.pmf([0, 1, 0, 1, 1]).astype(np.float32) | |||||
| pdf = Net() | |||||
| x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) | |||||
| output = pdf(x_) | |||||
| expect_cdf = bernoulli_benchmark.cdf([0, 0, 1, 0, 1]).astype(np.float32) | |||||
| x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32) | |||||
| cdf = CDF() | |||||
| output = cdf(x_) | |||||
| tol = 1e-6 | tol = 1e-6 | ||||
| assert (np.abs(output.asnumpy() - expect_pmf) < tol).all() | |||||
| assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() | |||||
| def test_log_likelihood(): | |||||
| class LogCDF(nn.Cell): | |||||
| """ | """ | ||||
| Test log_pmf. | |||||
| Test class: log cdf of bernoulli distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogCDF, self).__init__() | |||||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.b('log_cdf', x_) | |||||
| def test_logcdf(): | |||||
| """ | |||||
| Test log_cdf. | |||||
| """ | """ | ||||
| bernoulli_benchmark = stats.bernoulli(0.7) | bernoulli_benchmark = stats.bernoulli(0.7) | ||||
| expect_logpmf = bernoulli_benchmark.logpmf([0, 1, 0, 1, 1]).astype(np.float32) | |||||
| logprob = Net1() | |||||
| x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) | |||||
| output = logprob(x_) | |||||
| expect_logcdf = bernoulli_benchmark.logcdf([0, 0, 1, 0, 1]).astype(np.float32) | |||||
| x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32) | |||||
| logcdf = LogCDF() | |||||
| output = logcdf(x_) | |||||
| tol = 1e-6 | tol = 1e-6 | ||||
| assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all() | |||||
| assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() | |||||
| def test_kl_loss(): | |||||
| class SF(nn.Cell): | |||||
| """ | """ | ||||
| Test kl_loss. | |||||
| Test class: survival function of Bernoulli distributions. | |||||
| """ | """ | ||||
| probs1_a = 0.7 | |||||
| probs1_b = 0.5 | |||||
| probs0_a = 1 - probs1_a | |||||
| probs0_b = 1 - probs1_b | |||||
| expect_kl_loss = probs1_a * np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b) | |||||
| kl_loss = Net2() | |||||
| output = kl_loss(Tensor([probs1_b], dtype=dtype.float32)) | |||||
| def __init__(self): | |||||
| super(SF, self).__init__() | |||||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.b('survival_function', x_) | |||||
| def test_survival(): | |||||
| """ | |||||
| Test survival funciton. | |||||
| """ | |||||
| bernoulli_benchmark = stats.bernoulli(0.7) | |||||
| expect_survival = bernoulli_benchmark.sf([0, 1, 1, 0, 0]).astype(np.float32) | |||||
| x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(np.int32), dtype=dtype.float32) | |||||
| sf = SF() | |||||
| output = sf(x_) | |||||
| tol = 1e-6 | tol = 1e-6 | ||||
| assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() | |||||
| assert (np.abs(output.asnumpy() - expect_survival) < tol).all() | |||||
| def test_basics(): | |||||
| class LogSF(nn.Cell): | |||||
| """ | |||||
| Test class: log survival function of Bernoulli distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogSF, self).__init__() | |||||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.b('log_survival', x_) | |||||
| def test_log_survival(): | |||||
| """ | """ | ||||
| Test mean/standard deviation and probs. | |||||
| Test log survival funciton. | |||||
| """ | """ | ||||
| basics = Net3() | |||||
| mean, sd = basics() | |||||
| expect_mean = [0.5, 0.5] | |||||
| assert (mean.asnumpy() == expect_mean).all() | |||||
| assert (sd.asnumpy() == expect_mean).all() | |||||
| b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32) | |||||
| probs = b.probs() | |||||
| expect_probs = [0.7, 0.5] | |||||
| bernoulli_benchmark = stats.bernoulli(0.7) | |||||
| expect_logsurvival = bernoulli_benchmark.logsf([-1, 0.9, 0, 0, 0]).astype(np.float32) | |||||
| x_ = Tensor(np.array([-1, 0.9, 0, 0, 0]).astype(np.float32), dtype=dtype.float32) | |||||
| log_sf = LogSF() | |||||
| output = log_sf(x_) | |||||
| tol = 1e-6 | tol = 1e-6 | ||||
| assert (np.abs(probs.asnumpy() - expect_probs) < tol).all() | |||||
| assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() | |||||
| def test_sample(): | |||||
| class EntropyH(nn.Cell): | |||||
| """ | """ | ||||
| Test sample. | |||||
| Test class: entropy of Bernoulli distributions. | |||||
| """ | """ | ||||
| shape = (2, 3) | |||||
| sample = Net4(shape) | |||||
| output = sample() | |||||
| assert output.shape == (2, 3, 2) | |||||
| def __init__(self): | |||||
| super(EntropyH, self).__init__() | |||||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return self.b('entropy') | |||||
| def test_entropy(): | |||||
| """ | |||||
| Test entropy. | |||||
| """ | |||||
| bernoulli_benchmark = stats.bernoulli(0.7) | |||||
| expect_entropy = bernoulli_benchmark.entropy().astype(np.float32) | |||||
| entropy = EntropyH() | |||||
| output = entropy() | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() | |||||
| class CrossEntropy(nn.Cell): | |||||
| """ | |||||
| Test class: cross entropy between bernoulli distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(CrossEntropy, self).__init__() | |||||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| entropy = self.b('entropy') | |||||
| kl_loss = self.b('kl_loss', 'Bernoulli', x_) | |||||
| h_sum_kl = entropy + kl_loss | |||||
| cross_entropy = self.b('cross_entropy', 'Bernoulli', x_) | |||||
| return h_sum_kl - cross_entropy | |||||
| def test_cross_entropy(): | |||||
| """ | |||||
| Test cross_entropy. | |||||
| """ | |||||
| cross_entropy = CrossEntropy() | |||||
| prob = Tensor([0.3], dtype=dtype.float32) | |||||
| diff = cross_entropy(prob) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() | |||||
| @@ -0,0 +1,291 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """test cases for exponential distribution""" | |||||
| import numpy as np | |||||
| from scipy import stats | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.common.api import ms_function | |||||
| from mindspore import dtype | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Prob(nn.Cell): | |||||
| """ | |||||
| Test class: probability of Exponential distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Prob, self).__init__() | |||||
| self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.e('prob', x_) | |||||
| def test_pdf(): | |||||
| """ | |||||
| Test pdf. | |||||
| """ | |||||
| expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) | |||||
| expect_pdf = expon_benchmark.pdf([-1.0, 0.0, 1.0]).astype(np.float32) | |||||
| pdf = Prob() | |||||
| x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = pdf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() | |||||
| class LogProb(nn.Cell): | |||||
| """ | |||||
| Test class: log probability of Exponential distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogProb, self).__init__() | |||||
| self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.e('log_prob', x_) | |||||
| def test_log_likelihood(): | |||||
| """ | |||||
| Test log_pdf. | |||||
| """ | |||||
| expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) | |||||
| expect_logpdf = expon_benchmark.logpdf([0.5, 1.0, 2.0]).astype(np.float32) | |||||
| logprob = LogProb() | |||||
| x_ = Tensor(np.array([0.5, 1.0, 2.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = logprob(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() | |||||
| class KL(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss between Exponential distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(KL, self).__init__() | |||||
| self.e = nn.Exponential([1.5], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.e('kl_loss', 'Exponential', x_) | |||||
| def test_kl_loss(): | |||||
| """ | |||||
| Test kl_loss. | |||||
| """ | |||||
| rate_a = 1.5 | |||||
| rate_b = np.array([0.5, 2.0]).astype(np.float32) | |||||
| expect_kl_loss = np.log(rate_a) - np.log(rate_b) + rate_b / rate_a - 1.0 | |||||
| kl = KL() | |||||
| output = kl(Tensor(rate_b, dtype=dtype.float32)) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() | |||||
| class Basics(nn.Cell): | |||||
| """ | |||||
| Test class: mean/sd/mode of Exponential distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Basics, self).__init__() | |||||
| self.e = nn.Exponential([0.5], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return self.e('mean'), self.e('sd'), self.e('mode') | |||||
| def test_basics(): | |||||
| """ | |||||
| Test mean/standard/mode deviation. | |||||
| """ | |||||
| basics = Basics() | |||||
| mean, sd, mode = basics() | |||||
| expect_mean = 2. | |||||
| expect_sd = 2. | |||||
| expect_mode = 0. | |||||
| tol = 1e-6 | |||||
| assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() | |||||
| assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() | |||||
| assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() | |||||
| class Sampling(nn.Cell): | |||||
| """ | |||||
| Test class: sample of Exponential distribution. | |||||
| """ | |||||
| def __init__(self, shape, seed=0): | |||||
| super(Sampling, self).__init__() | |||||
| self.e = nn.Exponential([[1.0], [0.5]], seed=seed, dtype=dtype.float32) | |||||
| self.shape = shape | |||||
| @ms_function | |||||
| def construct(self, rate=None): | |||||
| return self.e('sample', self.shape, rate) | |||||
| def test_sample(): | |||||
| """ | |||||
| Test sample. | |||||
| """ | |||||
| shape = (2, 3) | |||||
| seed = 10 | |||||
| rate = Tensor([1.0, 2.0, 3.0], dtype=dtype.float32) | |||||
| sample = Sampling(shape, seed=seed) | |||||
| output = sample(rate) | |||||
| assert output.shape == (2, 3, 3) | |||||
| class CDF(nn.Cell): | |||||
| """ | |||||
| Test class: cdf of Exponential distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(CDF, self).__init__() | |||||
| self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.e('cdf', x_) | |||||
| def test_cdf(): | |||||
| """ | |||||
| Test cdf. | |||||
| """ | |||||
| expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) | |||||
| expect_cdf = expon_benchmark.cdf([-1.0, 0.0, 1.0]).astype(np.float32) | |||||
| cdf = CDF() | |||||
| x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = cdf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() | |||||
| class LogCDF(nn.Cell): | |||||
| """ | |||||
| Test class: log_cdf of Exponential distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogCDF, self).__init__() | |||||
| self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.e('log_cdf', x_) | |||||
| def test_log_cdf(): | |||||
| """ | |||||
| Test log_cdf. | |||||
| """ | |||||
| expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) | |||||
| expect_logcdf = expon_benchmark.logcdf([0.5, 1.0, 2.5]).astype(np.float32) | |||||
| logcdf = LogCDF() | |||||
| x_ = Tensor(np.array([0.5, 1.0, 2.5]).astype(np.float32), dtype=dtype.float32) | |||||
| output = logcdf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() | |||||
| class SF(nn.Cell): | |||||
| """ | |||||
| Test class: survival function of Exponential distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(SF, self).__init__() | |||||
| self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.e('survival_function', x_) | |||||
| def test_survival(): | |||||
| """ | |||||
| Test survival function. | |||||
| """ | |||||
| expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) | |||||
| expect_survival = expon_benchmark.sf([-1.0, 0.0, 1.0]).astype(np.float32) | |||||
| survival = SF() | |||||
| x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = survival(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_survival) < tol).all() | |||||
| class LogSF(nn.Cell): | |||||
| """ | |||||
| Test class: log survival function of Exponential distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogSF, self).__init__() | |||||
| self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.e('log_survival', x_) | |||||
| def test_log_survival(): | |||||
| """ | |||||
| Test log survival function. | |||||
| """ | |||||
| expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) | |||||
| expect_logsurvival = expon_benchmark.logsf([-1.0, 0.0, 1.0]).astype(np.float32) | |||||
| logsurvival = LogSF() | |||||
| x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = logsurvival(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() | |||||
| class EntropyH(nn.Cell): | |||||
| """ | |||||
| Test class: entropy of Exponential distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(EntropyH, self).__init__() | |||||
| self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return self.e('entropy') | |||||
| def test_entropy(): | |||||
| """ | |||||
| Test entropy. | |||||
| """ | |||||
| expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) | |||||
| expect_entropy = expon_benchmark.entropy().astype(np.float32) | |||||
| entropy = EntropyH() | |||||
| output = entropy() | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() | |||||
| class CrossEntropy(nn.Cell): | |||||
| """ | |||||
| Test class: cross entropy between Exponential distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(CrossEntropy, self).__init__() | |||||
| self.e = nn.Exponential([1.0], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| entropy = self.e('entropy') | |||||
| kl_loss = self.e('kl_loss', 'Exponential', x_) | |||||
| h_sum_kl = entropy + kl_loss | |||||
| cross_entropy = self.e('cross_entropy', 'Exponential', x_) | |||||
| return h_sum_kl - cross_entropy | |||||
| def test_cross_entropy(): | |||||
| """ | |||||
| Test cross_entropy. | |||||
| """ | |||||
| cross_entropy = CrossEntropy() | |||||
| rate = Tensor([0.5], dtype=dtype.float32) | |||||
| diff = cross_entropy(rate) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() | |||||
| @@ -0,0 +1,291 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """test cases for Geometric distribution""" | |||||
| import numpy as np | |||||
| from scipy import stats | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.common.api import ms_function | |||||
| from mindspore import dtype | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Prob(nn.Cell): | |||||
| """ | |||||
| Test class: probability of Geometric distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Prob, self).__init__() | |||||
| self.g = nn.Geometric(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.g('prob', x_) | |||||
| def test_pmf(): | |||||
| """ | |||||
| Test pmf. | |||||
| """ | |||||
| geom_benchmark = stats.geom(0.7) | |||||
| expect_pmf = geom_benchmark.pmf([0, 1, 2, 3, 4]).astype(np.float32) | |||||
| pdf = Prob() | |||||
| x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.float32), dtype=dtype.float32) | |||||
| output = pdf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_pmf) < tol).all() | |||||
| class LogProb(nn.Cell): | |||||
| """ | |||||
| Test class: log probability of Geometric distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogProb, self).__init__() | |||||
| self.g = nn.Geometric(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.g('log_prob', x_) | |||||
| def test_log_likelihood(): | |||||
| """ | |||||
| Test log_pmf. | |||||
| """ | |||||
| geom_benchmark = stats.geom(0.7) | |||||
| expect_logpmf = geom_benchmark.logpmf([1, 2, 3, 4, 5]).astype(np.float32) | |||||
| logprob = LogProb() | |||||
| x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.int32), dtype=dtype.float32) | |||||
| output = logprob(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all() | |||||
| class KL(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss between Geometric distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(KL, self).__init__() | |||||
| self.g = nn.Geometric(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.g('kl_loss', 'Geometric', x_) | |||||
| def test_kl_loss(): | |||||
| """ | |||||
| Test kl_loss. | |||||
| """ | |||||
| probs1_a = 0.7 | |||||
| probs1_b = 0.5 | |||||
| probs0_a = 1 - probs1_a | |||||
| probs0_b = 1 - probs1_b | |||||
| expect_kl_loss = np.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * np.log(probs0_a / probs0_b) | |||||
| kl_loss = KL() | |||||
| output = kl_loss(Tensor([probs1_b], dtype=dtype.float32)) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() | |||||
| class Basics(nn.Cell): | |||||
| """ | |||||
| Test class: mean/sd/mode of Geometric distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Basics, self).__init__() | |||||
| self.g = nn.Geometric([0.5, 0.5], dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return self.g('mean'), self.g('sd'), self.g('mode') | |||||
| def test_basics(): | |||||
| """ | |||||
| Test mean/standard deviation/mode. | |||||
| """ | |||||
| basics = Basics() | |||||
| mean, sd, mode = basics() | |||||
| expect_mean = [1.0, 1.0] | |||||
| expect_sd = np.sqrt(np.array([0.5, 0.5]) / np.square(np.array([0.5, 0.5]))) | |||||
| expect_mode = [0.0, 0.0] | |||||
| tol = 1e-6 | |||||
| assert (np.abs(mean.asnumpy()- expect_mean) < tol).all() | |||||
| assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() | |||||
| assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() | |||||
| class Sampling(nn.Cell): | |||||
| """ | |||||
| Test class: log probability of bernoulli distribution. | |||||
| """ | |||||
| def __init__(self, shape, seed=0): | |||||
| super(Sampling, self).__init__() | |||||
| self.g = nn.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32) | |||||
| self.shape = shape | |||||
| @ms_function | |||||
| def construct(self, probs=None): | |||||
| return self.g('sample', self.shape, probs) | |||||
| def test_sample(): | |||||
| """ | |||||
| Test sample. | |||||
| """ | |||||
| shape = (2, 3) | |||||
| sample = Sampling(shape) | |||||
| output = sample() | |||||
| assert output.shape == (2, 3, 2) | |||||
| class CDF(nn.Cell): | |||||
| """ | |||||
| Test class: cdf of Geometric distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(CDF, self).__init__() | |||||
| self.g = nn.Geometric(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.g('cdf', x_) | |||||
| def test_cdf(): | |||||
| """ | |||||
| Test cdf. | |||||
| """ | |||||
| geom_benchmark = stats.geom(0.7) | |||||
| expect_cdf = geom_benchmark.cdf([0, 1, 2, 3, 4]).astype(np.float32) | |||||
| x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.int32), dtype=dtype.float32) | |||||
| cdf = CDF() | |||||
| output = cdf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() | |||||
| class LogCDF(nn.Cell): | |||||
| """ | |||||
| Test class: log cdf of Geometric distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogCDF, self).__init__() | |||||
| self.g = nn.Geometric(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.g('log_cdf', x_) | |||||
| def test_logcdf(): | |||||
| """ | |||||
| Test log_cdf. | |||||
| """ | |||||
| geom_benchmark = stats.geom(0.7) | |||||
| expect_logcdf = geom_benchmark.logcdf([1, 2, 3, 4, 5]).astype(np.float32) | |||||
| x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.int32), dtype=dtype.float32) | |||||
| logcdf = LogCDF() | |||||
| output = logcdf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() | |||||
| class SF(nn.Cell): | |||||
| """ | |||||
| Test class: survial funciton of Geometric distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(SF, self).__init__() | |||||
| self.g = nn.Geometric(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.g('survival_function', x_) | |||||
| def test_survival(): | |||||
| """ | |||||
| Test survival function. | |||||
| """ | |||||
| geom_benchmark = stats.geom(0.7) | |||||
| expect_survival = geom_benchmark.sf([0, 1, 2, 3, 4]).astype(np.float32) | |||||
| x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.int32), dtype=dtype.float32) | |||||
| sf = SF() | |||||
| output = sf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_survival) < tol).all() | |||||
| class LogSF(nn.Cell): | |||||
| """ | |||||
| Test class: log survial funciton of Geometric distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogSF, self).__init__() | |||||
| self.g = nn.Geometric(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.g('log_survival', x_) | |||||
| def test_log_survival(): | |||||
| """ | |||||
| Test log_survival function. | |||||
| """ | |||||
| geom_benchmark = stats.geom(0.7) | |||||
| expect_logsurvival = geom_benchmark.logsf([0, 1, 2, 3, 4]).astype(np.float32) | |||||
| x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.float32), dtype=dtype.float32) | |||||
| log_sf = LogSF() | |||||
| output = log_sf(x_) | |||||
| tol = 5e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() | |||||
| class EntropyH(nn.Cell): | |||||
| """ | |||||
| Test class: entropy of Geometric distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(EntropyH, self).__init__() | |||||
| self.g = nn.Geometric(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return self.g('entropy') | |||||
| def test_entropy(): | |||||
| """ | |||||
| Test entropy. | |||||
| """ | |||||
| geom_benchmark = stats.geom(0.7) | |||||
| expect_entropy = geom_benchmark.entropy().astype(np.float32) | |||||
| entropy = EntropyH() | |||||
| output = entropy() | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() | |||||
| class CrossEntropy(nn.Cell): | |||||
| """ | |||||
| Test class: cross entropy between Geometric distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(CrossEntropy, self).__init__() | |||||
| self.g = nn.Geometric(0.7, dtype=dtype.int32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| entropy = self.g('entropy') | |||||
| kl_loss = self.g('kl_loss', 'Geometric', x_) | |||||
| h_sum_kl = entropy + kl_loss | |||||
| ans = self.g('cross_entropy', 'Geometric', x_) | |||||
| return h_sum_kl - ans | |||||
| def test_cross_entropy(): | |||||
| """ | |||||
| Test cross_entropy. | |||||
| """ | |||||
| cross_entropy = CrossEntropy() | |||||
| prob = Tensor([0.5], dtype=dtype.float32) | |||||
| diff = cross_entropy(prob) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() | |||||
| @@ -23,89 +23,66 @@ from mindspore import dtype | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| class Net(nn.Cell): | |||||
| class Prob(nn.Cell): | |||||
| """ | """ | ||||
| Test class: probability of normal distribution. | |||||
| Test class: probability of Normal distribution. | |||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Net, self).__init__() | |||||
| super(Prob, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) | self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) | ||||
| @ms_function | @ms_function | ||||
| def construct(self, x_): | def construct(self, x_): | ||||
| return self.n('prob', x_) | return self.n('prob', x_) | ||||
| class Net1(nn.Cell): | |||||
| """ | |||||
| Test class: log probability of normal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Net1, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.n('log_prob', x_) | |||||
| class Net2(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss of normal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Net2, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_, y_): | |||||
| return self.n('kl_loss', 'Normal', x_, y_) | |||||
| class Net3(nn.Cell): | |||||
| """ | |||||
| Test class: mean/sd of normal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Net3, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return self.n('mean'), self.n('sd') | |||||
| class Net4(nn.Cell): | |||||
| """ | |||||
| Test class: mean/sd of normal distribution. | |||||
| """ | |||||
| def __init__(self, shape, seed=0): | |||||
| super(Net4, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32) | |||||
| self.shape = shape | |||||
| @ms_function | |||||
| def construct(self, mean=None, sd=None): | |||||
| return self.n('sample', self.shape, mean, sd) | |||||
| def test_pdf(): | def test_pdf(): | ||||
| """ | """ | ||||
| Test pdf. | Test pdf. | ||||
| """ | """ | ||||
| norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) | norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) | ||||
| expect_pdf = norm_benchmark.pdf([1.0, 2.0]).astype(np.float32) | expect_pdf = norm_benchmark.pdf([1.0, 2.0]).astype(np.float32) | ||||
| pdf = Net() | |||||
| pdf = Prob() | |||||
| output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32)) | output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32)) | ||||
| tol = 1e-6 | tol = 1e-6 | ||||
| assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() | assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() | ||||
| class LogProb(nn.Cell): | |||||
| """ | |||||
| Test class: log probability of Normal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogProb, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.n('log_prob', x_) | |||||
| def test_log_likelihood(): | def test_log_likelihood(): | ||||
| """ | """ | ||||
| Test log_pdf. | Test log_pdf. | ||||
| """ | """ | ||||
| norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) | norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) | ||||
| expect_logpdf = norm_benchmark.logpdf([1.0, 2.0]).astype(np.float32) | expect_logpdf = norm_benchmark.logpdf([1.0, 2.0]).astype(np.float32) | ||||
| logprob = Net1() | |||||
| logprob = LogProb() | |||||
| output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32)) | output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32)) | ||||
| tol = 1e-6 | tol = 1e-6 | ||||
| assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() | assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() | ||||
| class KL(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss of Normal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(KL, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_, y_): | |||||
| return self.n('kl_loss', 'Normal', x_, y_) | |||||
| def test_kl_loss(): | def test_kl_loss(): | ||||
| """ | """ | ||||
| Test kl_loss. | Test kl_loss. | ||||
| @@ -120,25 +97,51 @@ def test_kl_loss(): | |||||
| squared_diff = np.square(mean_a / sd_b - mean_b / sd_b) | squared_diff = np.square(mean_a / sd_b - mean_b / sd_b) | ||||
| expect_kl_loss = 0.5 * squared_diff + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale | expect_kl_loss = 0.5 * squared_diff + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale | ||||
| kl_loss = Net2() | |||||
| kl_loss = KL() | |||||
| mean = Tensor(mean_b, dtype=dtype.float32) | mean = Tensor(mean_b, dtype=dtype.float32) | ||||
| sd = Tensor(sd_b, dtype=dtype.float32) | sd = Tensor(sd_b, dtype=dtype.float32) | ||||
| output = kl_loss(mean, sd) | output = kl_loss(mean, sd) | ||||
| tol = 1e-6 | tol = 1e-6 | ||||
| assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() | assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() | ||||
| class Basics(nn.Cell): | |||||
| """ | |||||
| Test class: mean/sd/mode of Normal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Basics, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return self.n('mean'), self.n('sd'), self.n('mode') | |||||
| def test_basics(): | def test_basics(): | ||||
| """ | """ | ||||
| Test mean/standard deviation. | |||||
| Test mean/standard deviation/mode. | |||||
| """ | """ | ||||
| basics = Net3() | |||||
| mean, sd = basics() | |||||
| basics = Basics() | |||||
| mean, sd, mode = basics() | |||||
| expect_mean = [3.0, 3.0] | expect_mean = [3.0, 3.0] | ||||
| expect_sd = [2.0, 4.0] | expect_sd = [2.0, 4.0] | ||||
| tol = 1e-6 | tol = 1e-6 | ||||
| assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() | assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() | ||||
| assert (np.abs(mode.asnumpy() - expect_mean) < tol).all() | |||||
| assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() | assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() | ||||
| class Sampling(nn.Cell): | |||||
| """ | |||||
| Test class: sample of Normal distribution. | |||||
| """ | |||||
| def __init__(self, shape, seed=0): | |||||
| super(Sampling, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32) | |||||
| self.shape = shape | |||||
| @ms_function | |||||
| def construct(self, mean=None, sd=None): | |||||
| return self.n('sample', self.shape, mean, sd) | |||||
| def test_sample(): | def test_sample(): | ||||
| """ | """ | ||||
| Test sample. | Test sample. | ||||
| @@ -147,6 +150,149 @@ def test_sample(): | |||||
| seed = 10 | seed = 10 | ||||
| mean = Tensor([2.0], dtype=dtype.float32) | mean = Tensor([2.0], dtype=dtype.float32) | ||||
| sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32) | sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32) | ||||
| sample = Net4(shape, seed=seed) | |||||
| sample = Sampling(shape, seed=seed) | |||||
| output = sample(mean, sd) | output = sample(mean, sd) | ||||
| assert output.shape == (2, 3, 3) | assert output.shape == (2, 3, 3) | ||||
| class CDF(nn.Cell): | |||||
| """ | |||||
| Test class: cdf of Normal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(CDF, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.n('cdf', x_) | |||||
| def test_cdf(): | |||||
| """ | |||||
| Test cdf. | |||||
| """ | |||||
| norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) | |||||
| expect_cdf = norm_benchmark.cdf([1.0, 2.0]).astype(np.float32) | |||||
| cdf = CDF() | |||||
| output = cdf(Tensor([1.0, 2.0], dtype=dtype.float32)) | |||||
| tol = 2e-5 | |||||
| assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() | |||||
| class LogCDF(nn.Cell): | |||||
| """ | |||||
| Test class: log_cdf of Mormal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogCDF, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.n('log_cdf', x_) | |||||
| def test_log_cdf(): | |||||
| """ | |||||
| Test log cdf. | |||||
| """ | |||||
| norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) | |||||
| expect_logcdf = norm_benchmark.logcdf([1.0, 2.0]).astype(np.float32) | |||||
| logcdf = LogCDF() | |||||
| output = logcdf(Tensor([1.0, 2.0], dtype=dtype.float32)) | |||||
| tol = 5e-5 | |||||
| assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() | |||||
| class SF(nn.Cell): | |||||
| """ | |||||
| Test class: survival function of Normal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(SF, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.n('survival_function', x_) | |||||
| def test_survival(): | |||||
| """ | |||||
| Test log_survival. | |||||
| """ | |||||
| norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) | |||||
| expect_survival = norm_benchmark.sf([1.0, 2.0]).astype(np.float32) | |||||
| survival_function = SF() | |||||
| output = survival_function(Tensor([1.0, 2.0], dtype=dtype.float32)) | |||||
| tol = 2e-5 | |||||
| assert (np.abs(output.asnumpy() - expect_survival) < tol).all() | |||||
| class LogSF(nn.Cell): | |||||
| """ | |||||
| Test class: log survival function of Normal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogSF, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.n('log_survival', x_) | |||||
| def test_log_survival(): | |||||
| """ | |||||
| Test log_survival. | |||||
| """ | |||||
| norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) | |||||
| expect_log_survival = norm_benchmark.logsf([1.0, 2.0]).astype(np.float32) | |||||
| log_survival = LogSF() | |||||
| output = log_survival(Tensor([1.0, 2.0], dtype=dtype.float32)) | |||||
| tol = 2e-5 | |||||
| assert (np.abs(output.asnumpy() - expect_log_survival) < tol).all() | |||||
| class EntropyH(nn.Cell): | |||||
| """ | |||||
| Test class: entropy of Normal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(EntropyH, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return self.n('entropy') | |||||
| def test_entropy(): | |||||
| """ | |||||
| Test entropy. | |||||
| """ | |||||
| norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) | |||||
| expect_entropy = norm_benchmark.entropy().astype(np.float32) | |||||
| entropy = EntropyH() | |||||
| output = entropy() | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() | |||||
| class CrossEntropy(nn.Cell): | |||||
| """ | |||||
| Test class: cross entropy between Normal distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(CrossEntropy, self).__init__() | |||||
| self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_, y_): | |||||
| entropy = self.n('entropy') | |||||
| kl_loss = self.n('kl_loss', 'Normal', x_, y_) | |||||
| h_sum_kl = entropy + kl_loss | |||||
| cross_entropy = self.n('cross_entropy', 'Normal', x_, y_) | |||||
| return h_sum_kl - cross_entropy | |||||
| def test_cross_entropy(): | |||||
| """ | |||||
| Test cross_entropy. | |||||
| """ | |||||
| cross_entropy = CrossEntropy() | |||||
| mean = Tensor([1.0], dtype=dtype.float32) | |||||
| sd = Tensor([1.0], dtype=dtype.float32) | |||||
| diff = cross_entropy(mean, sd) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() | |||||
| @@ -0,0 +1,293 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """test cases for uniform distribution""" | |||||
| import numpy as np | |||||
| from scipy import stats | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.common.api import ms_function | |||||
| from mindspore import dtype | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Prob(nn.Cell): | |||||
| """ | |||||
| Test class: probability of Uniform distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Prob, self).__init__() | |||||
| self.u = nn.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.u('prob', x_) | |||||
| def test_pdf(): | |||||
| """ | |||||
| Test pdf. | |||||
| """ | |||||
| uniform_benchmark = stats.uniform([0.0], [[1.0], [2.0]]) | |||||
| expect_pdf = uniform_benchmark.pdf([-1.0, 0.0, 0.5, 1.0, 1.5, 3.0]).astype(np.float32) | |||||
| pdf = Prob() | |||||
| x_ = Tensor(np.array([-1.0, 0.0, 0.5, 1.0, 1.5, 3.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = pdf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() | |||||
| class LogProb(nn.Cell): | |||||
| """ | |||||
| Test class: log probability of Uniform distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogProb, self).__init__() | |||||
| self.u = nn.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.u('log_prob', x_) | |||||
| def test_log_likelihood(): | |||||
| """ | |||||
| Test log_pdf. | |||||
| """ | |||||
| uniform_benchmark = stats.uniform([0.0], [[1.0], [2.0]]) | |||||
| expect_logpdf = uniform_benchmark.logpdf([0.5]).astype(np.float32) | |||||
| logprob = LogProb() | |||||
| x_ = Tensor(np.array([0.5]).astype(np.float32), dtype=dtype.float32) | |||||
| output = logprob(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() | |||||
| class KL(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss between Uniform distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(KL, self).__init__() | |||||
| self.u = nn.Uniform([0.0], [1.5], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_, y_): | |||||
| return self.u('kl_loss', 'Uniform', x_, y_) | |||||
| def test_kl_loss(): | |||||
| """ | |||||
| Test kl_loss. | |||||
| """ | |||||
| low_a = 0.0 | |||||
| high_a = 1.5 | |||||
| low_b = -1.0 | |||||
| high_b = 2.0 | |||||
| expect_kl_loss = np.log(high_b - low_b) / np.log(high_a - low_a) | |||||
| kl = KL() | |||||
| output = kl(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32)) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() | |||||
| class Basics(nn.Cell): | |||||
| """ | |||||
| Test class: mean/sd of Uniform distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Basics, self).__init__() | |||||
| self.u = nn.Uniform([0.0], [3.0], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return self.u('mean'), self.u('sd') | |||||
| def test_basics(): | |||||
| """ | |||||
| Test mean/standard deviation. | |||||
| """ | |||||
| basics = Basics() | |||||
| mean, sd = basics() | |||||
| expect_mean = [1.5] | |||||
| expect_sd = np.sqrt([0.75]) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() | |||||
| assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() | |||||
| class Sampling(nn.Cell): | |||||
| """ | |||||
| Test class: sample of Uniform distribution. | |||||
| """ | |||||
| def __init__(self, shape, seed=0): | |||||
| super(Sampling, self).__init__() | |||||
| self.u = nn.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32) | |||||
| self.shape = shape | |||||
| @ms_function | |||||
| def construct(self, low=None, high=None): | |||||
| return self.u('sample', self.shape, low, high) | |||||
| def test_sample(): | |||||
| """ | |||||
| Test sample. | |||||
| """ | |||||
| shape = (2, 3) | |||||
| seed = 10 | |||||
| low = Tensor([1.0], dtype=dtype.float32) | |||||
| high = Tensor([2.0, 3.0, 4.0], dtype=dtype.float32) | |||||
| sample = Sampling(shape, seed=seed) | |||||
| output = sample(low, high) | |||||
| assert output.shape == (2, 3, 3) | |||||
| class CDF(nn.Cell): | |||||
| """ | |||||
| Test class: cdf of Uniform distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(CDF, self).__init__() | |||||
| self.u = nn.Uniform([0.0], [1.0], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.u('cdf', x_) | |||||
| def test_cdf(): | |||||
| """ | |||||
| Test cdf. | |||||
| """ | |||||
| uniform_benchmark = stats.uniform([0.0], [1.0]) | |||||
| expect_cdf = uniform_benchmark.cdf([-1.0, 0.5, 1.0, 2.0]).astype(np.float32) | |||||
| cdf = CDF() | |||||
| x_ = Tensor(np.array([-1.0, 0.5, 1.0, 2.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = cdf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() | |||||
| class LogCDF(nn.Cell): | |||||
| """ | |||||
| Test class: log_cdf of Uniform distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogCDF, self).__init__() | |||||
| self.u = nn.Uniform([0.0], [1.0], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.u('log_cdf', x_) | |||||
| class SF(nn.Cell): | |||||
| """ | |||||
| Test class: survival function of Uniform distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(SF, self).__init__() | |||||
| self.u = nn.Uniform([0.0], [1.0], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.u('survival_function', x_) | |||||
| class LogSF(nn.Cell): | |||||
| """ | |||||
| Test class: log survival function of Uniform distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LogSF, self).__init__() | |||||
| self.u = nn.Uniform([0.0], [1.0], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_): | |||||
| return self.u('log_survival', x_) | |||||
| class EntropyH(nn.Cell): | |||||
| """ | |||||
| Test class: entropy of Uniform distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(EntropyH, self).__init__() | |||||
| self.u = nn.Uniform([0.0], [1.0, 2.0], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return self.u('entropy') | |||||
| def test_entropy(): | |||||
| """ | |||||
| Test entropy. | |||||
| """ | |||||
| uniform_benchmark = stats.uniform([0.0], [1.0, 2.0]) | |||||
| expect_entropy = uniform_benchmark.entropy().astype(np.float32) | |||||
| entropy = EntropyH() | |||||
| output = entropy() | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() | |||||
| class CrossEntropy(nn.Cell): | |||||
| """ | |||||
| Test class: cross_entropy between Uniform distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(CrossEntropy, self).__init__() | |||||
| self.u = nn.Uniform([0.0], [1.5], dtype=dtype.float32) | |||||
| @ms_function | |||||
| def construct(self, x_, y_): | |||||
| entropy = self.u('entropy') | |||||
| kl_loss = self.u('kl_loss', 'Uniform', x_, y_) | |||||
| h_sum_kl = entropy + kl_loss | |||||
| cross_entropy = self.u('cross_entropy', 'Uniform', x_, y_) | |||||
| return h_sum_kl - cross_entropy | |||||
| def test_log_cdf(): | |||||
| """ | |||||
| Test log_cdf. | |||||
| """ | |||||
| uniform_benchmark = stats.uniform([0.0], [1.0]) | |||||
| expect_logcdf = uniform_benchmark.logcdf([0.5, 0.8, 2.0]).astype(np.float32) | |||||
| logcdf = LogCDF() | |||||
| x_ = Tensor(np.array([0.5, 0.8, 2.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = logcdf(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() | |||||
| def test_survival(): | |||||
| """ | |||||
| Test survival function. | |||||
| """ | |||||
| uniform_benchmark = stats.uniform([0.0], [1.0]) | |||||
| expect_survival = uniform_benchmark.sf([-1.0, 0.5, 1.0, 2.0]).astype(np.float32) | |||||
| survival = SF() | |||||
| x_ = Tensor(np.array([-1.0, 0.5, 1.0, 2.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = survival(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_survival) < tol).all() | |||||
| def test_log_survival(): | |||||
| """ | |||||
| Test log survival function. | |||||
| """ | |||||
| uniform_benchmark = stats.uniform([0.0], [1.0]) | |||||
| expect_logsurvival = uniform_benchmark.logsf([0.5, 0.8, -2.0]).astype(np.float32) | |||||
| logsurvival = LogSF() | |||||
| x_ = Tensor(np.array([0.5, 0.8, -2.0]).astype(np.float32), dtype=dtype.float32) | |||||
| output = logsurvival(x_) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() | |||||
| def test_cross_entropy(): | |||||
| """ | |||||
| Test cross_entropy. | |||||
| """ | |||||
| cross_entropy = CrossEntropy() | |||||
| low_b = -1.0 | |||||
| high_b = 2.0 | |||||
| diff = cross_entropy(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32)) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() | |||||
| @@ -0,0 +1,165 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Test nn.Distribution.Bernoulli. | |||||
| """ | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| from mindspore import dtype | |||||
| from mindspore import Tensor | |||||
| def test_arguments(): | |||||
| """ | |||||
| Args passing during initialization. | |||||
| """ | |||||
| b = nn.Bernoulli() | |||||
| assert isinstance(b, nn.Distribution) | |||||
| b = nn.Bernoulli([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32) | |||||
| assert isinstance(b, nn.Distribution) | |||||
| def test_prob(): | |||||
| """ | |||||
| Invalid probability. | |||||
| """ | |||||
| with pytest.raises(ValueError): | |||||
| nn.Bernoulli([-0.1], dtype=dtype.int32) | |||||
| with pytest.raises(ValueError): | |||||
| nn.Bernoulli([1.1], dtype=dtype.int32) | |||||
| class BernoulliProb(nn.Cell): | |||||
| """ | |||||
| Bernoulli distribution: initialize with probs. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BernoulliProb, self).__init__() | |||||
| self.b = nn.Bernoulli(0.5, dtype=dtype.int32) | |||||
| def construct(self, value): | |||||
| prob = self.b('prob', value) | |||||
| log_prob = self.b('log_prob', value) | |||||
| cdf = self.b('cdf', value) | |||||
| log_cdf = self.b('log_cdf', value) | |||||
| sf = self.b('survival_function', value) | |||||
| log_sf = self.b('log_survival', value) | |||||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||||
| def test_bernoulli_prob(): | |||||
| """ | |||||
| Test probability functions: passing value through construct. | |||||
| """ | |||||
| net = BernoulliProb() | |||||
| value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) | |||||
| ans = net(value) | |||||
| assert isinstance(ans, Tensor) | |||||
| class BernoulliProb1(nn.Cell): | |||||
| """ | |||||
| Bernoulli distribution: initialize without probs. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BernoulliProb1, self).__init__() | |||||
| self.b = nn.Bernoulli(dtype=dtype.int32) | |||||
| def construct(self, value, probs): | |||||
| prob = self.b('prob', value, probs) | |||||
| log_prob = self.b('log_prob', value, probs) | |||||
| cdf = self.b('cdf', value, probs) | |||||
| log_cdf = self.b('log_cdf', value, probs) | |||||
| sf = self.b('survival_function', value, probs) | |||||
| log_sf = self.b('log_survival', value, probs) | |||||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||||
| def test_bernoulli_prob1(): | |||||
| """ | |||||
| Test probability functions: passing value/probs through construct. | |||||
| """ | |||||
| net = BernoulliProb1() | |||||
| value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) | |||||
| probs = Tensor([0.5], dtype=dtype.float32) | |||||
| ans = net(value, probs) | |||||
| assert isinstance(ans, Tensor) | |||||
| class BernoulliKl(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss between Bernoulli distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BernoulliKl, self).__init__() | |||||
| self.b1 = nn.Bernoulli(0.7, dtype=dtype.int32) | |||||
| self.b2 = nn.Bernoulli(dtype=dtype.int32) | |||||
| def construct(self, probs_b, probs_a): | |||||
| kl1 = self.b1('kl_loss', 'Bernoulli', probs_b) | |||||
| kl2 = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a) | |||||
| return kl1 + kl2 | |||||
| def test_kl(): | |||||
| """ | |||||
| Test kl_loss function. | |||||
| """ | |||||
| ber_net = BernoulliKl() | |||||
| probs_b = Tensor([0.3], dtype=dtype.float32) | |||||
| probs_a = Tensor([0.7], dtype=dtype.float32) | |||||
| ans = ber_net(probs_b, probs_a) | |||||
| assert isinstance(ans, Tensor) | |||||
| class BernoulliCrossEntropy(nn.Cell): | |||||
| """ | |||||
| Test class: cross_entropy of Bernoulli distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BernoulliCrossEntropy, self).__init__() | |||||
| self.b1 = nn.Bernoulli(0.7, dtype=dtype.int32) | |||||
| self.b2 = nn.Bernoulli(dtype=dtype.int32) | |||||
| def construct(self, probs_b, probs_a): | |||||
| h1 = self.b1('cross_entropy', 'Bernoulli', probs_b) | |||||
| h2 = self.b2('cross_entropy', 'Bernoulli', probs_b, probs_a) | |||||
| return h1 + h2 | |||||
| def test_cross_entropy(): | |||||
| """ | |||||
| Test cross_entropy between Bernoulli distributions. | |||||
| """ | |||||
| net = BernoulliCrossEntropy() | |||||
| probs_b = Tensor([0.3], dtype=dtype.float32) | |||||
| probs_a = Tensor([0.7], dtype=dtype.float32) | |||||
| ans = net(probs_b, probs_a) | |||||
| assert isinstance(ans, Tensor) | |||||
| class BernoulliBasics(nn.Cell): | |||||
| """ | |||||
| Test class: basic mean/sd/var/mode/entropy function. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BernoulliBasics, self).__init__() | |||||
| self.b = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32) | |||||
| def construct(self): | |||||
| mean = self.b('mean') | |||||
| sd = self.b('sd') | |||||
| var = self.b('var') | |||||
| mode = self.b('mode') | |||||
| entropy = self.b('entropy') | |||||
| return mean + sd + var + mode + entropy | |||||
| def test_bascis(): | |||||
| """ | |||||
| Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. | |||||
| """ | |||||
| net = BernoulliBasics() | |||||
| ans = net() | |||||
| assert isinstance(ans, Tensor) | |||||
| @@ -0,0 +1,166 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Test nn.Distribution.Exponential. | |||||
| """ | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| from mindspore import dtype | |||||
| from mindspore import Tensor | |||||
| def test_arguments(): | |||||
| """ | |||||
| Args passing during initialization. | |||||
| """ | |||||
| e = nn.Exponential() | |||||
| assert isinstance(e, nn.Distribution) | |||||
| e = nn.Exponential([0.1, 0.3, 0.5, 1.0], dtype=dtype.float32) | |||||
| assert isinstance(e, nn.Distribution) | |||||
| def test_rate(): | |||||
| """ | |||||
| Invalid rate. | |||||
| """ | |||||
| with pytest.raises(ValueError): | |||||
| nn.Exponential([-0.1], dtype=dtype.float32) | |||||
| with pytest.raises(ValueError): | |||||
| nn.Exponential([0.0], dtype=dtype.float32) | |||||
| class ExponentialProb(nn.Cell): | |||||
| """ | |||||
| Exponential distribution: initialize with rate. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(ExponentialProb, self).__init__() | |||||
| self.e = nn.Exponential(0.5, dtype=dtype.float32) | |||||
| def construct(self, value): | |||||
| prob = self.e('prob', value) | |||||
| log_prob = self.e('log_prob', value) | |||||
| cdf = self.e('cdf', value) | |||||
| log_cdf = self.e('log_cdf', value) | |||||
| sf = self.e('survival_function', value) | |||||
| log_sf = self.e('log_survival', value) | |||||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||||
| def test_exponential_prob(): | |||||
| """ | |||||
| Test probability functions: passing value through construct. | |||||
| """ | |||||
| net = ExponentialProb() | |||||
| value = Tensor([0.2, 0.3, 5.0, 2, 3.9], dtype=dtype.float32) | |||||
| ans = net(value) | |||||
| assert isinstance(ans, Tensor) | |||||
| class ExponentialProb1(nn.Cell): | |||||
| """ | |||||
| Exponential distribution: initialize without rate. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(ExponentialProb1, self).__init__() | |||||
| self.e = nn.Exponential(dtype=dtype.float32) | |||||
| def construct(self, value, rate): | |||||
| prob = self.e('prob', value, rate) | |||||
| log_prob = self.e('log_prob', value, rate) | |||||
| cdf = self.e('cdf', value, rate) | |||||
| log_cdf = self.e('log_cdf', value, rate) | |||||
| sf = self.e('survival_function', value, rate) | |||||
| log_sf = self.e('log_survival', value, rate) | |||||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||||
| def test_exponential_prob1(): | |||||
| """ | |||||
| Test probability functions: passing value/rate through construct. | |||||
| """ | |||||
| net = ExponentialProb1() | |||||
| value = Tensor([0.2, 0.9, 1, 2, 3], dtype=dtype.float32) | |||||
| rate = Tensor([0.5], dtype=dtype.float32) | |||||
| ans = net(value, rate) | |||||
| assert isinstance(ans, Tensor) | |||||
| class ExponentialKl(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss between Exponential distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(ExponentialKl, self).__init__() | |||||
| self.e1 = nn.Exponential(0.7, dtype=dtype.float32) | |||||
| self.e2 = nn.Exponential(dtype=dtype.float32) | |||||
| def construct(self, rate_b, rate_a): | |||||
| kl1 = self.e1('kl_loss', 'Exponential', rate_b) | |||||
| kl2 = self.e2('kl_loss', 'Exponential', rate_b, rate_a) | |||||
| return kl1 + kl2 | |||||
| def test_kl(): | |||||
| """ | |||||
| Test kl_loss function. | |||||
| """ | |||||
| net = ExponentialKl() | |||||
| rate_b = Tensor([0.3], dtype=dtype.float32) | |||||
| rate_a = Tensor([0.7], dtype=dtype.float32) | |||||
| ans = net(rate_b, rate_a) | |||||
| assert isinstance(ans, Tensor) | |||||
| class ExponentialCrossEntropy(nn.Cell): | |||||
| """ | |||||
| Test class: cross_entropy of Exponential distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(ExponentialCrossEntropy, self).__init__() | |||||
| self.e1 = nn.Exponential(0.3, dtype=dtype.float32) | |||||
| self.e2 = nn.Exponential(dtype=dtype.float32) | |||||
| def construct(self, rate_b, rate_a): | |||||
| h1 = self.e1('cross_entropy', 'Exponential', rate_b) | |||||
| h2 = self.e2('cross_entropy', 'Exponential', rate_b, rate_a) | |||||
| return h1 + h2 | |||||
| def test_cross_entropy(): | |||||
| """ | |||||
| Test cross_entropy between Exponential distributions. | |||||
| """ | |||||
| net = ExponentialCrossEntropy() | |||||
| rate_b = Tensor([0.3], dtype=dtype.float32) | |||||
| rate_a = Tensor([0.7], dtype=dtype.float32) | |||||
| ans = net(rate_b, rate_a) | |||||
| assert isinstance(ans, Tensor) | |||||
| class ExponentialBasics(nn.Cell): | |||||
| """ | |||||
| Test class: basic mean/sd/mode/entropy function. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(ExponentialBasics, self).__init__() | |||||
| self.e = nn.Exponential([0.3, 0.5], dtype=dtype.float32) | |||||
| def construct(self): | |||||
| mean = self.e('mean') | |||||
| sd = self.e('sd') | |||||
| var = self.e('var') | |||||
| mode = self.e('mode') | |||||
| entropy = self.e('entropy') | |||||
| return mean + sd + var + mode + entropy | |||||
| def test_bascis(): | |||||
| """ | |||||
| Test mean/sd/var/mode/entropy functionality of Exponential distribution. | |||||
| """ | |||||
| net = ExponentialBasics() | |||||
| ans = net() | |||||
| assert isinstance(ans, Tensor) | |||||
| @@ -0,0 +1,167 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Test nn.Distribution.Geometric. | |||||
| """ | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| from mindspore import dtype | |||||
| from mindspore import Tensor | |||||
| def test_arguments(): | |||||
| """ | |||||
| Args passing during initialization. | |||||
| """ | |||||
| g = nn.Geometric() | |||||
| assert isinstance(g, nn.Distribution) | |||||
| g = nn.Geometric([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32) | |||||
| assert isinstance(g, nn.Distribution) | |||||
| def test_prob(): | |||||
| """ | |||||
| Invalid probability. | |||||
| """ | |||||
| with pytest.raises(ValueError): | |||||
| nn.Geometric([-0.1], dtype=dtype.int32) | |||||
| with pytest.raises(ValueError): | |||||
| nn.Geometric([1.1], dtype=dtype.int32) | |||||
| class GeometricProb(nn.Cell): | |||||
| """ | |||||
| Geometric distribution: initialize with probs. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(GeometricProb, self).__init__() | |||||
| self.g = nn.Geometric(0.5, dtype=dtype.int32) | |||||
| def construct(self, value): | |||||
| prob = self.g('prob', value) | |||||
| log_prob = self.g('log_prob', value) | |||||
| cdf = self.g('cdf', value) | |||||
| log_cdf = self.g('log_cdf', value) | |||||
| sf = self.g('survival_function', value) | |||||
| log_sf = self.g('log_survival', value) | |||||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||||
| def test_geometric_prob(): | |||||
| """ | |||||
| Test probability functions: passing value through construct. | |||||
| """ | |||||
| net = GeometricProb() | |||||
| value = Tensor([3, 4, 5, 6, 7], dtype=dtype.float32) | |||||
| ans = net(value) | |||||
| assert isinstance(ans, Tensor) | |||||
| class GeometricProb1(nn.Cell): | |||||
| """ | |||||
| Geometric distribution: initialize without probs. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(GeometricProb1, self).__init__() | |||||
| self.g = nn.Geometric(dtype=dtype.int32) | |||||
| def construct(self, value, probs): | |||||
| prob = self.g('prob', value, probs) | |||||
| log_prob = self.g('log_prob', value, probs) | |||||
| cdf = self.g('cdf', value, probs) | |||||
| log_cdf = self.g('log_cdf', value, probs) | |||||
| sf = self.g('survival_function', value, probs) | |||||
| log_sf = self.g('log_survival', value, probs) | |||||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||||
| def test_geometric_prob1(): | |||||
| """ | |||||
| Test probability functions: passing value/probs through construct. | |||||
| """ | |||||
| net = GeometricProb1() | |||||
| value = Tensor([3, 4, 5, 6, 7], dtype=dtype.float32) | |||||
| probs = Tensor([0.5], dtype=dtype.float32) | |||||
| ans = net(value, probs) | |||||
| assert isinstance(ans, Tensor) | |||||
| class GeometricKl(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss between Geometric distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(GeometricKl, self).__init__() | |||||
| self.g1 = nn.Geometric(0.7, dtype=dtype.int32) | |||||
| self.g2 = nn.Geometric(dtype=dtype.int32) | |||||
| def construct(self, probs_b, probs_a): | |||||
| kl1 = self.g1('kl_loss', 'Geometric', probs_b) | |||||
| kl2 = self.g2('kl_loss', 'Geometric', probs_b, probs_a) | |||||
| return kl1 + kl2 | |||||
| def test_kl(): | |||||
| """ | |||||
| Test kl_loss function. | |||||
| """ | |||||
| ber_net = GeometricKl() | |||||
| probs_b = Tensor([0.3], dtype=dtype.float32) | |||||
| probs_a = Tensor([0.7], dtype=dtype.float32) | |||||
| ans = ber_net(probs_b, probs_a) | |||||
| assert isinstance(ans, Tensor) | |||||
| class GeometricCrossEntropy(nn.Cell): | |||||
| """ | |||||
| Test class: cross_entropy of Geometric distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(GeometricCrossEntropy, self).__init__() | |||||
| self.g1 = nn.Geometric(0.3, dtype=dtype.int32) | |||||
| self.g2 = nn.Geometric(dtype=dtype.int32) | |||||
| def construct(self, probs_b, probs_a): | |||||
| h1 = self.g1('cross_entropy', 'Geometric', probs_b) | |||||
| h2 = self.g2('cross_entropy', 'Geometric', probs_b, probs_a) | |||||
| return h1 + h2 | |||||
| def test_cross_entropy(): | |||||
| """ | |||||
| Test cross_entropy between Geometric distributions. | |||||
| """ | |||||
| net = GeometricCrossEntropy() | |||||
| probs_b = Tensor([0.3], dtype=dtype.float32) | |||||
| probs_a = Tensor([0.7], dtype=dtype.float32) | |||||
| ans = net(probs_b, probs_a) | |||||
| assert isinstance(ans, Tensor) | |||||
| class GeometricBasics(nn.Cell): | |||||
| """ | |||||
| Test class: basic mean/sd/mode/entropy function. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(GeometricBasics, self).__init__() | |||||
| self.g = nn.Geometric([0.3, 0.5], dtype=dtype.int32) | |||||
| def construct(self): | |||||
| mean = self.g('mean') | |||||
| sd = self.g('sd') | |||||
| var = self.g('var') | |||||
| mode = self.g('mode') | |||||
| entropy = self.g('entropy') | |||||
| return mean + sd + var + mode + entropy | |||||
| def test_bascis(): | |||||
| """ | |||||
| Test mean/sd/mode/entropy functionality of Geometric distribution. | |||||
| """ | |||||
| net = GeometricBasics() | |||||
| ans = net() | |||||
| assert isinstance(ans, Tensor) | |||||
| @@ -0,0 +1,171 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Test nn.Distribution.Normal. | |||||
| """ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| from mindspore import dtype | |||||
| from mindspore import Tensor | |||||
| def test_normal_shape_errpr(): | |||||
| """ | |||||
| Invalid shapes. | |||||
| """ | |||||
| with pytest.raises(ValueError): | |||||
| nn.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) | |||||
| def test_arguments(): | |||||
| """ | |||||
| args passing during initialization. | |||||
| """ | |||||
| n = nn.Normal() | |||||
| assert isinstance(n, nn.Distribution) | |||||
| n = nn.Normal([3.0], [4.0], dtype=dtype.float32) | |||||
| assert isinstance(n, nn.Distribution) | |||||
| class NormalProb(nn.Cell): | |||||
| """ | |||||
| Normal distribution: initialize with mean/sd. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(NormalProb, self).__init__() | |||||
| self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) | |||||
| def construct(self, value): | |||||
| prob = self.normal('prob', value) | |||||
| log_prob = self.normal('log_prob', value) | |||||
| cdf = self.normal('cdf', value) | |||||
| log_cdf = self.normal('log_cdf', value) | |||||
| sf = self.normal('survival_function', value) | |||||
| log_sf = self.normal('log_survival', value) | |||||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||||
| def test_normal_prob(): | |||||
| """ | |||||
| Test probability functions: passing value through construct. | |||||
| """ | |||||
| net = NormalProb() | |||||
| value = Tensor([0.5, 1.0], dtype=dtype.float32) | |||||
| ans = net(value) | |||||
| assert isinstance(ans, Tensor) | |||||
| class NormalProb1(nn.Cell): | |||||
| """ | |||||
| Normal distribution: initialize without mean/sd. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(NormalProb1, self).__init__() | |||||
| self.normal = nn.Normal() | |||||
| def construct(self, value, mean, sd): | |||||
| prob = self.normal('prob', value, mean, sd) | |||||
| log_prob = self.normal('log_prob', value, mean, sd) | |||||
| cdf = self.normal('cdf', value, mean, sd) | |||||
| log_cdf = self.normal('log_cdf', value, mean, sd) | |||||
| sf = self.normal('survival_function', value, mean, sd) | |||||
| log_sf = self.normal('log_survival', value, mean, sd) | |||||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||||
| def test_normal_prob1(): | |||||
| """ | |||||
| Test probability functions: passing mean/sd, value through construct. | |||||
| """ | |||||
| net = NormalProb1() | |||||
| value = Tensor([0.5, 1.0], dtype=dtype.float32) | |||||
| mean = Tensor([0.0], dtype=dtype.float32) | |||||
| sd = Tensor([1.0], dtype=dtype.float32) | |||||
| ans = net(value, mean, sd) | |||||
| assert isinstance(ans, Tensor) | |||||
| class NormalKl(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss of Normal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(NormalKl, self).__init__() | |||||
| self.n1 = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) | |||||
| self.n2 = nn.Normal(dtype=dtype.float32) | |||||
| def construct(self, mean_b, sd_b, mean_a, sd_a): | |||||
| kl1 = self.n1('kl_loss', 'Normal', mean_b, sd_b) | |||||
| kl2 = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a) | |||||
| return kl1 + kl2 | |||||
| def test_kl(): | |||||
| """ | |||||
| Test kl_loss. | |||||
| """ | |||||
| net = NormalKl() | |||||
| mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) | |||||
| sd_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) | |||||
| mean_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) | |||||
| sd_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) | |||||
| ans = net(mean_b, sd_b, mean_a, sd_a) | |||||
| assert isinstance(ans, Tensor) | |||||
| class NormalCrossEntropy(nn.Cell): | |||||
| """ | |||||
| Test class: cross_entropy of Normal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(NormalCrossEntropy, self).__init__() | |||||
| self.n1 = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) | |||||
| self.n2 = nn.Normal(dtype=dtype.float32) | |||||
| def construct(self, mean_b, sd_b, mean_a, sd_a): | |||||
| h1 = self.n1('cross_entropy', 'Normal', mean_b, sd_b) | |||||
| h2 = self.n2('cross_entropy', 'Normal', mean_b, sd_b, mean_a, sd_a) | |||||
| return h1 + h2 | |||||
| def test_cross_entropy(): | |||||
| """ | |||||
| Test cross entropy between Normal distributions. | |||||
| """ | |||||
| net = NormalCrossEntropy() | |||||
| mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) | |||||
| sd_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) | |||||
| mean_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) | |||||
| sd_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) | |||||
| ans = net(mean_b, sd_b, mean_a, sd_a) | |||||
| assert isinstance(ans, Tensor) | |||||
| class NormalBasics(nn.Cell): | |||||
| """ | |||||
| Test class: basic mean/sd function. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(NormalBasics, self).__init__() | |||||
| self.n = nn.Normal(3.0, 4.0, dtype=dtype.float32) | |||||
| def construct(self): | |||||
| mean = self.n('mean') | |||||
| sd = self.n('sd') | |||||
| mode = self.n('mode') | |||||
| entropy = self.n('entropy') | |||||
| return mean + sd + mode + entropy | |||||
| def test_bascis(): | |||||
| """ | |||||
| Test mean/sd/mode/entropy functionality of Normal. | |||||
| """ | |||||
| net = NormalBasics() | |||||
| ans = net() | |||||
| assert isinstance(ans, Tensor) | |||||
| @@ -0,0 +1,180 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Test nn.Distribution.Uniform. | |||||
| """ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| from mindspore import dtype | |||||
| from mindspore import Tensor | |||||
| def test_uniform_shape_errpr(): | |||||
| """ | |||||
| Invalid shapes. | |||||
| """ | |||||
| with pytest.raises(ValueError): | |||||
| nn.Uniform([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) | |||||
| def test_arguments(): | |||||
| """ | |||||
| Args passing during initialization. | |||||
| """ | |||||
| u = nn.Uniform() | |||||
| assert isinstance(u, nn.Distribution) | |||||
| u = nn.Uniform([3.0], [4.0], dtype=dtype.float32) | |||||
| assert isinstance(u, nn.Distribution) | |||||
| def test_invalid_range(): | |||||
| """ | |||||
| Test range of uniform distribution. | |||||
| """ | |||||
| with pytest.raises(ValueError): | |||||
| nn.Uniform(0.0, 0.0, dtype=dtype.float32) | |||||
| with pytest.raises(ValueError): | |||||
| nn.Uniform(1.0, 0.0, dtype=dtype.float32) | |||||
| class UniformProb(nn.Cell): | |||||
| """ | |||||
| Uniform distribution: initialize with low/high. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(UniformProb, self).__init__() | |||||
| self.u = nn.Uniform(3.0, 4.0, dtype=dtype.float32) | |||||
| def construct(self, value): | |||||
| prob = self.u('prob', value) | |||||
| log_prob = self.u('log_prob', value) | |||||
| cdf = self.u('cdf', value) | |||||
| log_cdf = self.u('log_cdf', value) | |||||
| sf = self.u('survival_function', value) | |||||
| log_sf = self.u('log_survival', value) | |||||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||||
| def test_uniform_prob(): | |||||
| """ | |||||
| Test probability functions: passing value through construct. | |||||
| """ | |||||
| net = UniformProb() | |||||
| value = Tensor([3.1, 3.2, 3.3, 3.4], dtype=dtype.float32) | |||||
| ans = net(value) | |||||
| assert isinstance(ans, Tensor) | |||||
| class UniformProb1(nn.Cell): | |||||
| """ | |||||
| Uniform distribution: initialize without low/high. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(UniformProb1, self).__init__() | |||||
| self.u = nn.Uniform(dtype=dtype.float32) | |||||
| def construct(self, value, low, high): | |||||
| prob = self.u('prob', value, low, high) | |||||
| log_prob = self.u('log_prob', value, low, high) | |||||
| cdf = self.u('cdf', value, low, high) | |||||
| log_cdf = self.u('log_cdf', value, low, high) | |||||
| sf = self.u('survival_function', value, low, high) | |||||
| log_sf = self.u('log_survival', value, low, high) | |||||
| return prob + log_prob + cdf + log_cdf + sf + log_sf | |||||
| def test_uniform_prob1(): | |||||
| """ | |||||
| Test probability functions: passing low/high, value through construct. | |||||
| """ | |||||
| net = UniformProb1() | |||||
| value = Tensor([0.1, 0.2, 0.3, 0.9], dtype=dtype.float32) | |||||
| low = Tensor([0.0], dtype=dtype.float32) | |||||
| high = Tensor([1.0], dtype=dtype.float32) | |||||
| ans = net(value, low, high) | |||||
| assert isinstance(ans, Tensor) | |||||
| class UniformKl(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss of Uniform distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(UniformKl, self).__init__() | |||||
| self.u1 = nn.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) | |||||
| self.u2 = nn.Uniform(dtype=dtype.float32) | |||||
| def construct(self, low_b, high_b, low_a, high_a): | |||||
| kl1 = self.u1('kl_loss', 'Uniform', low_b, high_b) | |||||
| kl2 = self.u2('kl_loss', 'Uniform', low_b, high_b, low_a, high_a) | |||||
| return kl1 + kl2 | |||||
| def test_kl(): | |||||
| """ | |||||
| Test kl_loss. | |||||
| """ | |||||
| net = UniformKl() | |||||
| low_b = Tensor(np.array([0.0]).astype(np.float32), dtype=dtype.float32) | |||||
| high_b = Tensor(np.array([5.0]).astype(np.float32), dtype=dtype.float32) | |||||
| low_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) | |||||
| high_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) | |||||
| ans = net(low_b, high_b, low_a, high_a) | |||||
| assert isinstance(ans, Tensor) | |||||
| class UniformCrossEntropy(nn.Cell): | |||||
| """ | |||||
| Test class: cross_entropy of Uniform distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(UniformCrossEntropy, self).__init__() | |||||
| self.u1 = nn.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) | |||||
| self.u2 = nn.Uniform(dtype=dtype.float32) | |||||
| def construct(self, low_b, high_b, low_a, high_a): | |||||
| h1 = self.u1('cross_entropy', 'Uniform', low_b, high_b) | |||||
| h2 = self.u2('cross_entropy', 'Uniform', low_b, high_b, low_a, high_a) | |||||
| return h1 + h2 | |||||
| def test_cross_entropy(): | |||||
| """ | |||||
| Test cross_entropy between Unifrom distributions. | |||||
| """ | |||||
| net = UniformCrossEntropy() | |||||
| low_b = Tensor(np.array([0.0]).astype(np.float32), dtype=dtype.float32) | |||||
| high_b = Tensor(np.array([5.0]).astype(np.float32), dtype=dtype.float32) | |||||
| low_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) | |||||
| high_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) | |||||
| ans = net(low_b, high_b, low_a, high_a) | |||||
| assert isinstance(ans, Tensor) | |||||
| class UniformBasics(nn.Cell): | |||||
| """ | |||||
| Test class: basic mean/sd/var/mode/entropy function. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(UniformBasics, self).__init__() | |||||
| self.u = nn.Uniform(3.0, 4.0, dtype=dtype.float32) | |||||
| def construct(self): | |||||
| mean = self.u('mean') | |||||
| sd = self.u('sd') | |||||
| var = self.u('var') | |||||
| entropy = self.u('entropy') | |||||
| return mean + sd + var + entropy | |||||
| def test_bascis(): | |||||
| """ | |||||
| Test mean/sd/var/mode/entropy functionality of Uniform. | |||||
| """ | |||||
| net = UniformBasics() | |||||
| ans = net() | |||||
| assert isinstance(ans, Tensor) | |||||
| @@ -1,369 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Test nn.Distribution. | |||||
| Including Normal Distribution and Bernoulli Distribution. | |||||
| """ | |||||
| import pytest | |||||
| import numpy as np | |||||
| import mindspore.nn as nn | |||||
| from mindspore import dtype | |||||
| from mindspore import Tensor | |||||
| def test_normal_shape_errpr(): | |||||
| """ | |||||
| Invalid shapes. | |||||
| """ | |||||
| with pytest.raises(ValueError): | |||||
| nn.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) | |||||
| def test_no_arguments(): | |||||
| """ | |||||
| No args passed in during initialization. | |||||
| """ | |||||
| n = nn.Normal() | |||||
| assert isinstance(n, nn.Distribution) | |||||
| b = nn.Bernoulli() | |||||
| assert isinstance(b, nn.Distribution) | |||||
| def test_with_arguments(): | |||||
| """ | |||||
| Args passed in during initialization. | |||||
| """ | |||||
| n = nn.Normal([3.0], [4.0], dtype=dtype.float32) | |||||
| assert isinstance(n, nn.Distribution) | |||||
| b = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32) | |||||
| assert isinstance(b, nn.Distribution) | |||||
| class NormalProb(nn.Cell): | |||||
| """ | |||||
| Normal distribution: initialize with mean/sd. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(NormalProb, self).__init__() | |||||
| self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) | |||||
| def construct(self, value): | |||||
| x = self.normal('prob', value) | |||||
| y = self.normal('log_prob', value) | |||||
| return x, y | |||||
| def test_normal_prob(): | |||||
| """ | |||||
| Test pdf/log_pdf: passing value through construct. | |||||
| """ | |||||
| net = NormalProb() | |||||
| value = Tensor([0.5, 1.0], dtype=dtype.float32) | |||||
| pdf, log_pdf = net(value) | |||||
| assert isinstance(pdf, Tensor) | |||||
| assert isinstance(log_pdf, Tensor) | |||||
| class NormalProb1(nn.Cell): | |||||
| """ | |||||
| Normal distribution: initialize without mean/sd. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(NormalProb1, self).__init__() | |||||
| self.normal = nn.Normal() | |||||
| def construct(self, value, mean, sd): | |||||
| x = self.normal('prob', value, mean, sd) | |||||
| y = self.normal('log_prob', value, mean, sd) | |||||
| return x, y | |||||
| def test_normal_prob1(): | |||||
| """ | |||||
| Test pdf/logpdf: passing mean/sd, value through construct. | |||||
| """ | |||||
| net = NormalProb1() | |||||
| value = Tensor([0.5, 1.0], dtype=dtype.float32) | |||||
| mean = Tensor([0.0], dtype=dtype.float32) | |||||
| sd = Tensor([1.0], dtype=dtype.float32) | |||||
| pdf, log_pdf = net(value, mean, sd) | |||||
| assert isinstance(pdf, Tensor) | |||||
| assert isinstance(log_pdf, Tensor) | |||||
| class NormalProb2(nn.Cell): | |||||
| """ | |||||
| Normal distribution: initialize with mean/sd. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(NormalProb2, self).__init__() | |||||
| self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) | |||||
| def construct(self, value, mean, sd): | |||||
| x = self.normal('prob', value, mean, sd) | |||||
| y = self.normal('log_prob', value, mean, sd) | |||||
| return x, y | |||||
| def test_normal_prob2(): | |||||
| """ | |||||
| Test pdf/log_pdf: passing mean/sd through construct. | |||||
| Overwrite original mean/sd. | |||||
| """ | |||||
| net = NormalProb2() | |||||
| value = Tensor([0.5, 1.0], dtype=dtype.float32) | |||||
| mean = Tensor([0.0], dtype=dtype.float32) | |||||
| sd = Tensor([1.0], dtype=dtype.float32) | |||||
| pdf, log_pdf = net(value, mean, sd) | |||||
| assert isinstance(pdf, Tensor) | |||||
| assert isinstance(log_pdf, Tensor) | |||||
| class BernoulliProb(nn.Cell): | |||||
| """ | |||||
| Bernoulli distribution: initialize with probs. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BernoulliProb, self).__init__() | |||||
| self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32) | |||||
| def construct(self, value): | |||||
| return self.bernoulli('prob', value) | |||||
| class BernoulliLogProb(nn.Cell): | |||||
| """ | |||||
| Bernoulli distribution: initialize with probs. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BernoulliLogProb, self).__init__() | |||||
| self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32) | |||||
| def construct(self, value): | |||||
| return self.bernoulli('log_prob', value) | |||||
| def test_bernoulli_prob(): | |||||
| """ | |||||
| Test pmf/log_pmf: passing value through construct. | |||||
| """ | |||||
| net = BernoulliProb() | |||||
| value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) | |||||
| pmf = net(value) | |||||
| assert isinstance(pmf, Tensor) | |||||
| def test_bernoulli_log_prob(): | |||||
| """ | |||||
| Test pmf/log_pmf: passing value through construct. | |||||
| """ | |||||
| net = BernoulliLogProb() | |||||
| value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) | |||||
| log_pmf = net(value) | |||||
| assert isinstance(log_pmf, Tensor) | |||||
| class BernoulliProb1(nn.Cell): | |||||
| """ | |||||
| Bernoulli distribution: initialize without probs. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BernoulliProb1, self).__init__() | |||||
| self.bernoulli = nn.Bernoulli() | |||||
| def construct(self, value, probs): | |||||
| return self.bernoulli('prob', value, probs) | |||||
| class BernoulliLogProb1(nn.Cell): | |||||
| """ | |||||
| Bernoulli distribution: initialize without probs. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BernoulliLogProb1, self).__init__() | |||||
| self.bernoulli = nn.Bernoulli() | |||||
| def construct(self, value, probs): | |||||
| return self.bernoulli('log_prob', value, probs) | |||||
| def test_bernoulli_prob1(): | |||||
| """ | |||||
| Test pmf/log_pmf: passing probs through construct. | |||||
| """ | |||||
| net = BernoulliProb1() | |||||
| value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) | |||||
| probs = Tensor([0.3], dtype=dtype.float32) | |||||
| pmf = net(value, probs) | |||||
| assert isinstance(pmf, Tensor) | |||||
| def test_bernoulli_log_prob1(): | |||||
| """ | |||||
| Test pmf/log_pmf: passing probs through construct. | |||||
| """ | |||||
| net = BernoulliLogProb1() | |||||
| value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) | |||||
| probs = Tensor([0.3], dtype=dtype.float32) | |||||
| log_pmf = net(value, probs) | |||||
| assert isinstance(log_pmf, Tensor) | |||||
| class BernoulliProb2(nn.Cell): | |||||
| """ | |||||
| Bernoulli distribution: initialize with probs. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BernoulliProb2, self).__init__() | |||||
| self.bernoulli = nn.Bernoulli(0.5) | |||||
| def construct(self, value, probs): | |||||
| return self.bernoulli('prob', value, probs) | |||||
| class BernoulliLogProb2(nn.Cell): | |||||
| """ | |||||
| Bernoulli distribution: initialize with probs. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BernoulliLogProb2, self).__init__() | |||||
| self.bernoulli = nn.Bernoulli(0.5) | |||||
| def construct(self, value, probs): | |||||
| return self.bernoulli('log_prob', value, probs) | |||||
| def test_bernoulli_prob2(): | |||||
| """ | |||||
| Test pmf/log_pmf: passing probs/value through construct. | |||||
| Overwrite original probs. | |||||
| """ | |||||
| net = BernoulliProb2() | |||||
| value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) | |||||
| probs = Tensor([0.3], dtype=dtype.float32) | |||||
| pmf = net(value, probs) | |||||
| assert isinstance(pmf, Tensor) | |||||
| def test_bernoulli_log_prob2(): | |||||
| """ | |||||
| Test pmf/log_pmf: passing probs/value through construct. | |||||
| Overwrite original probs. | |||||
| """ | |||||
| net = BernoulliLogProb2() | |||||
| value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) | |||||
| probs = Tensor([0.3], dtype=dtype.float32) | |||||
| log_pmf = net(value, probs) | |||||
| assert isinstance(log_pmf, Tensor) | |||||
| class NormalKl(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss of Normal distribution. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(NormalKl, self).__init__() | |||||
| self.n = nn.Normal(Tensor([3.0]), Tensor([4.0]), dtype=dtype.float32) | |||||
| def construct(self, x_, y_): | |||||
| return self.n('kl_loss', 'Normal', x_, y_) | |||||
| class BernoulliKl(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss between Bernoulli distributions. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BernoulliKl, self).__init__() | |||||
| self.b = nn.Bernoulli(0.7, dtype=dtype.int32) | |||||
| def construct(self, x_): | |||||
| return self.b('kl_loss', 'Bernoulli', x_) | |||||
| def test_kl(): | |||||
| """ | |||||
| Test kl_loss function. | |||||
| """ | |||||
| nor_net = NormalKl() | |||||
| mean_b = np.array([1.0]).astype(np.float32) | |||||
| sd_b = np.array([1.0]).astype(np.float32) | |||||
| mean = Tensor(mean_b, dtype=dtype.float32) | |||||
| sd = Tensor(sd_b, dtype=dtype.float32) | |||||
| loss = nor_net(mean, sd) | |||||
| assert isinstance(loss, Tensor) | |||||
| ber_net = BernoulliKl() | |||||
| probs_b = Tensor([0.3], dtype=dtype.float32) | |||||
| loss = ber_net(probs_b) | |||||
| assert isinstance(loss, Tensor) | |||||
| class NormalKlNoArgs(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss of Normal distribution. | |||||
| No args during initialization. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(NormalKlNoArgs, self).__init__() | |||||
| self.n = nn.Normal(dtype=dtype.float32) | |||||
| def construct(self, x_, y_, w_, v_): | |||||
| return self.n('kl_loss', 'Normal', x_, y_, w_, v_) | |||||
| class BernoulliKlNoArgs(nn.Cell): | |||||
| """ | |||||
| Test class: kl_loss between Bernoulli distributions. | |||||
| No args during initialization. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(BernoulliKlNoArgs, self).__init__() | |||||
| self.b = nn.Bernoulli(dtype=dtype.int32) | |||||
| def construct(self, x_, y_): | |||||
| return self.b('kl_loss', 'Bernoulli', x_, y_) | |||||
| def test_kl_no_args(): | |||||
| """ | |||||
| Test kl_loss function. | |||||
| """ | |||||
| nor_net = NormalKlNoArgs() | |||||
| mean_b = np.array([1.0]).astype(np.float32) | |||||
| sd_b = np.array([1.0]).astype(np.float32) | |||||
| mean_a = np.array([2.0]).astype(np.float32) | |||||
| sd_a = np.array([3.0]).astype(np.float32) | |||||
| mean_b = Tensor(mean_b, dtype=dtype.float32) | |||||
| sd_b = Tensor(sd_b, dtype=dtype.float32) | |||||
| mean_a = Tensor(mean_a, dtype=dtype.float32) | |||||
| sd_a = Tensor(sd_a, dtype=dtype.float32) | |||||
| loss = nor_net(mean_b, sd_b, mean_a, sd_a) | |||||
| assert isinstance(loss, Tensor) | |||||
| ber_net = BernoulliKlNoArgs() | |||||
| probs_b = Tensor([0.3], dtype=dtype.float32) | |||||
| probs_a = Tensor([0.7], dtype=dtype.float32) | |||||
| loss = ber_net(probs_b, probs_a) | |||||
| assert isinstance(loss, Tensor) | |||||
| class NormalBernoulli(nn.Cell): | |||||
| """ | |||||
| Test class: basic mean/sd function. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(NormalBernoulli, self).__init__() | |||||
| self.n = nn.Normal(3.0, 4.0, dtype=dtype.float32) | |||||
| self.b = nn.Bernoulli(0.5, dtype=dtype.int32) | |||||
| def construct(self): | |||||
| normal_mean = self.n('mean') | |||||
| normal_sd = self.n('sd') | |||||
| bernoulli_mean = self.b('mean') | |||||
| bernoulli_sd = self.b('sd') | |||||
| return normal_mean, normal_sd, bernoulli_mean, bernoulli_sd | |||||
| def test_bascis(): | |||||
| """ | |||||
| Test mean/sd functionality of Normal and Bernoulli. | |||||
| """ | |||||
| net = NormalBernoulli() | |||||
| normal_mean, normal_sd, bernoulli_mean, bernoulli_sd = net() | |||||
| assert isinstance(normal_mean, Tensor) | |||||
| assert isinstance(normal_sd, Tensor) | |||||
| assert isinstance(bernoulli_mean, Tensor) | |||||
| assert isinstance(bernoulli_sd, Tensor) | |||||