| @@ -110,6 +110,8 @@ def check_scalar_from_param(params): | |||
| Notes: String parameters are excluded. | |||
| """ | |||
| for value in params.values(): | |||
| if value is None: | |||
| continue | |||
| if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)): | |||
| return params['distribution'].is_scalar_batch | |||
| if isinstance(value, Parameter): | |||
| @@ -358,23 +360,29 @@ class CheckTensor(PrimitiveWithInfer): | |||
| return x | |||
| raise TypeError(f"For {name}, input type should be a Tensor or Parameter.") | |||
| def common_dtype(arg_a, name_a, arg_b, name_b, hint_type): | |||
| def set_param_type(args, hint_type): | |||
| """ | |||
| check if arg_a and arg_b have the same dtype. | |||
| Find the common type among arguments. | |||
| Args: | |||
| args (dict): dictionary of arguments, {'name':value}. | |||
| hint_type (mindspore.dtype): hint type to return. | |||
| Raises: | |||
| TypeError: if tensors in args are not the same dtype. | |||
| """ | |||
| if hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype'): | |||
| if isinstance(arg_a, np.ndarray): | |||
| a_dtype = mstype.pytype_to_dtype(arg_a.dtype) | |||
| else: | |||
| a_dtype = arg_a.dtype | |||
| if isinstance(arg_b, np.ndarray): | |||
| b_dtype = mstype.pytype_to_dtype(arg_b.dtype) | |||
| else: | |||
| b_dtype = arg_b.dtype | |||
| if a_dtype != b_dtype: | |||
| raise TypeError(f"{name_a} and {name_b} should have the same dtype.") | |||
| int_type = mstype.int_type + mstype.uint_type | |||
| if a_dtype in int_type or a_dtype == mstype.float64: | |||
| return mstype.float32 | |||
| return a_dtype | |||
| return hint_type | |||
| common_dtype = None | |||
| for name, arg in args.items(): | |||
| if hasattr(arg, 'dtype'): | |||
| if isinstance(arg, np.ndarray): | |||
| cur_dtype = mstype.pytype_to_dtype(arg.dtype) | |||
| else: | |||
| cur_dtype = arg.dtype | |||
| if common_dtype is None: | |||
| common_dtype = cur_dtype | |||
| elif cur_dtype != common_dtype: | |||
| raise TypeError(f"{name} should have the same dtype as other arguments.") | |||
| int_type = mstype.int_type + mstype.uint_type | |||
| if common_dtype in int_type or common_dtype == mstype.float64: | |||
| return mstype.float32 | |||
| return hint_type if common_dtype is None else common_dtype | |||
| @@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error | |||
| from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type | |||
| from ._utils.custom_ops import exp_generic, log_generic, erf_generic | |||
| @@ -119,13 +119,16 @@ class Bernoulli(Distribution): | |||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Bernoulli, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = mstype.float32 | |||
| self.parameter_type = set_param_type({'probs1': probs}, mstype.float32) | |||
| if probs is not None: | |||
| self._probs = cast_to_tensor(probs, mstype.float32) | |||
| self._probs = cast_to_tensor(probs, self.parameter_type) | |||
| check_prob(self.probs) | |||
| else: | |||
| self._probs = probs | |||
| self.default_parameters = [self.probs] | |||
| self.parameter_names = ['probs1'] | |||
| # ops needed for the class | |||
| self.exp = exp_generic | |||
| self.log = log_generic | |||
| @@ -157,24 +160,12 @@ class Bernoulli(Distribution): | |||
| """ | |||
| return self._probs | |||
| def _check_param(self, probs1): | |||
| """ | |||
| Check availablity of distribution specific args `probs1`. | |||
| """ | |||
| if probs1 is not None: | |||
| if self.context_mode == 0: | |||
| self.checktensor(probs1, 'probs1') | |||
| else: | |||
| probs1 = self.checktensor(probs1, 'probs1') | |||
| return self.cast(probs1, self.parameter_type) | |||
| return self.probs if self.probs is not None else raise_none_error('probs1') | |||
| def _mean(self, probs1=None): | |||
| r""" | |||
| .. math:: | |||
| MEAN(B) = probs1 | |||
| """ | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| return probs1 | |||
| def _mode(self, probs1=None): | |||
| @@ -182,7 +173,7 @@ class Bernoulli(Distribution): | |||
| .. math:: | |||
| MODE(B) = 1 if probs1 > 0.5 else 0 | |||
| """ | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| prob_type = self.dtypeop(probs1) | |||
| zeros = self.fill(prob_type, self.shape(probs1), 0.0) | |||
| ones = self.fill(prob_type, self.shape(probs1), 1.0) | |||
| @@ -194,7 +185,7 @@ class Bernoulli(Distribution): | |||
| .. math:: | |||
| VAR(B) = probs1 * probs0 | |||
| """ | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| probs0 = 1.0 - probs1 | |||
| return self.exp(self.log(probs0) + self.log(probs1)) | |||
| @@ -203,11 +194,11 @@ class Bernoulli(Distribution): | |||
| .. math:: | |||
| H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) | |||
| """ | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| probs0 = 1.0 - probs1 | |||
| return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) | |||
| def _cross_entropy(self, dist, probs1_b, probs1_a=None): | |||
| def _cross_entropy(self, dist, probs1_b, probs1=None): | |||
| """ | |||
| Evaluate cross_entropy between Bernoulli distributions. | |||
| @@ -217,7 +208,7 @@ class Bernoulli(Distribution): | |||
| probs1 (Tensor): `probs1` of distribution a. Default: self.probs. | |||
| """ | |||
| check_distribution_name(dist, 'Bernoulli') | |||
| return self._entropy(probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) | |||
| return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1) | |||
| def _log_prob(self, value, probs1=None): | |||
| r""" | |||
| @@ -233,7 +224,7 @@ class Bernoulli(Distribution): | |||
| """ | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, mstype.float32) | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| probs0 = 1.0 - probs1 | |||
| return self.log(probs1) * value + self.log(probs0) * (1.0 - value) | |||
| @@ -253,7 +244,7 @@ class Bernoulli(Distribution): | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, mstype.float32) | |||
| value = self.floor(value) | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| prob_type = self.dtypeop(probs1) | |||
| value = value * self.fill(prob_type, self.shape(probs1), 1.0) | |||
| probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0) | |||
| @@ -264,7 +255,7 @@ class Bernoulli(Distribution): | |||
| less_than_zero = self.select(comp_zero, zeros, probs0) | |||
| return self.select(comp_one, less_than_zero, ones) | |||
| def _kl_loss(self, dist, probs1_b, probs1_a=None): | |||
| def _kl_loss(self, dist, probs1_b, probs1=None): | |||
| r""" | |||
| Evaluate the Bernoulli-Bernoulli KL divergence, i.e. KL(a||b). | |||
| @@ -280,7 +271,7 @@ class Bernoulli(Distribution): | |||
| check_distribution_name(dist, 'Bernoulli') | |||
| probs1_b = self._check_value(probs1_b, 'probs1_b') | |||
| probs1_b = self.cast(probs1_b, self.parameter_type) | |||
| probs1_a = self._check_param(probs1_a) | |||
| probs1_a = self._check_param_type(probs1) | |||
| probs0_a = 1.0 - probs1_a | |||
| probs0_b = 1.0 - probs1_b | |||
| return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b) | |||
| @@ -297,7 +288,7 @@ class Bernoulli(Distribution): | |||
| Tensor, shape is shape + batch_shape. | |||
| """ | |||
| shape = self.checktuple(shape, 'shape') | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| origin_shape = shape + self.shape(probs1) | |||
| if origin_shape == (): | |||
| sample_shape = (1,) | |||
| @@ -18,7 +18,8 @@ from mindspore.nn.cell import Cell | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._checkparam import Rel | |||
| from mindspore.common import get_seed | |||
| from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device | |||
| from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device,\ | |||
| raise_none_error | |||
| from ._utils.utils import CheckTuple, CheckTensor | |||
| @@ -115,6 +116,51 @@ class Distribution(Cell): | |||
| def broadcast_shape(self): | |||
| return self._broadcast_shape | |||
| def _check_param_type(self, *args): | |||
| """ | |||
| Check the availability and validity of default parameters and dist_spec_args. | |||
| dist_spec_args passed in must be tensors. If a default parameter of the distribution | |||
| is None, that parameter must be passed in through `args`. | |||
| """ | |||
| broadcast_shape = None | |||
| common_dtype = None | |||
| out = [] | |||
| for arg, name, default in zip(args, self.parameter_names, self.default_parameters): | |||
| # check if the argument is a Tensor | |||
| if arg is not None: | |||
| if self.context_mode == 0: | |||
| self.checktensor(arg, name) | |||
| else: | |||
| arg = self.checktensor(arg, name) | |||
| else: | |||
| arg = default if default is not None else raise_none_error(name) | |||
| # broadcast if the number of args > 1 | |||
| if broadcast_shape is None: | |||
| broadcast_shape = self.shape(arg) | |||
| common_dtype = self.dtypeop(arg) | |||
| else: | |||
| ones = self.fill(self.dtypeop(arg), broadcast_shape, 1.0) | |||
| broadcast_shape = self.shape(arg + ones) | |||
| # check if the arguments have the same dtype | |||
| arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0) | |||
| dtype_tensor = self.fill(common_dtype, broadcast_shape, 1.0) | |||
| self.sametypeshape(arg, dtype_tensor) | |||
| arg = self.cast(arg, self.parameter_type) | |||
| out.append(arg) | |||
| if len(out) == 1: | |||
| return out[0] | |||
| # broadcast all args to broadcast_shape | |||
| result = () | |||
| for arg in out: | |||
| arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0) | |||
| result = result + (arg,) | |||
| return result | |||
| def _check_value(self, value, name): | |||
| """ | |||
| Check availability of `value` as a Tensor. | |||
| @@ -211,163 +257,203 @@ class Distribution(Cell): | |||
| if hasattr(self, '_cross_entropy'): | |||
| self._call_cross_entropy = self._cross_entropy | |||
| def log_prob(self, *args, **kwargs): | |||
| def log_prob(self, value, *args, **kwargs): | |||
| """ | |||
| Evaluate the log probability(pdf or pmf) at the given value. | |||
| Args: | |||
| value (Tensor): value to be evaluated. | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| The argument `args` must include `value`. | |||
| dist_spec_args are optional. | |||
| A distribution can be optionally passed to the function by passing its dist_spec_args through | |||
| `args` or `kwargs`. | |||
| """ | |||
| return self._call_log_prob(*args, **kwargs) | |||
| return self._call_log_prob(value, *args, **kwargs) | |||
| def _calc_prob_from_log_prob(self, *args, **kwargs): | |||
| def _calc_prob_from_log_prob(self, value, *args, **kwargs): | |||
| r""" | |||
| Evaluate prob from log probability. | |||
| .. math:: | |||
| probability(x) = \exp(log_likelihood(x)) | |||
| """ | |||
| return self.exp(self._log_prob(*args, **kwargs)) | |||
| return self.exp(self._log_prob(value, *args, **kwargs)) | |||
| def prob(self, *args, **kwargs): | |||
| def prob(self, value, *args, **kwargs): | |||
| """ | |||
| Evaluate the probability (pdf or pmf) at given value. | |||
| Args: | |||
| value (Tensor): value to be evaluated. | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| The argument `args` must include `value`. | |||
| dist_spec_args are optional. | |||
| A distribution can be optionally passed to the function by passing its dist_spec_args through | |||
| `args` or `kwargs`. | |||
| """ | |||
| return self._call_prob(*args, **kwargs) | |||
| return self._call_prob(value, *args, **kwargs) | |||
| def _calc_log_prob_from_prob(self, *args, **kwargs): | |||
| def _calc_log_prob_from_prob(self, value, *args, **kwargs): | |||
| r""" | |||
| Evaluate log probability from probability. | |||
| .. math:: | |||
| log_prob(x) = \log(prob(x)) | |||
| """ | |||
| return self.log(self._prob(*args, **kwargs)) | |||
| return self.log(self._prob(value, *args, **kwargs)) | |||
| def cdf(self, *args, **kwargs): | |||
| def cdf(self, value, *args, **kwargs): | |||
| """ | |||
| Evaluate the cdf at given value. | |||
| Args: | |||
| value (Tensor): value to be evaluated. | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| The argument `args` must include `value`. | |||
| dist_spec_args are optional. | |||
| A distribution can be optionally passed to the function by passing its dist_spec_args through | |||
| `args` or `kwargs`. | |||
| """ | |||
| return self._call_cdf(*args, **kwargs) | |||
| return self._call_cdf(value, *args, **kwargs) | |||
| def _calc_cdf_from_log_cdf(self, *args, **kwargs): | |||
| def _calc_cdf_from_log_cdf(self, value, *args, **kwargs): | |||
| r""" | |||
| Evaluate cdf from log_cdf. | |||
| .. math:: | |||
| cdf(x) = \exp(log_cdf(x)) | |||
| """ | |||
| return self.exp(self._log_cdf(*args, **kwargs)) | |||
| return self.exp(self._log_cdf(value, *args, **kwargs)) | |||
| def _calc_cdf_from_survival(self, *args, **kwargs): | |||
| def _calc_cdf_from_survival(self, value, *args, **kwargs): | |||
| r""" | |||
| Evaluate cdf from survival function. | |||
| .. math:: | |||
| cdf(x) = 1 - (survival_function(x)) | |||
| """ | |||
| return 1.0 - self._survival_function(*args, **kwargs) | |||
| return 1.0 - self._survival_function(value, *args, **kwargs) | |||
| def _calc_cdf_from_log_survival(self, *args, **kwargs): | |||
| def _calc_cdf_from_log_survival(self, value, *args, **kwargs): | |||
| r""" | |||
| Evaluate cdf from log survival function. | |||
| .. math:: | |||
| cdf(x) = 1 - (\exp(log_survival(x))) | |||
| """ | |||
| return 1.0 - self.exp(self._log_survival(*args, **kwargs)) | |||
| return 1.0 - self.exp(self._log_survival(value, *args, **kwargs)) | |||
| def log_cdf(self, *args, **kwargs): | |||
| def log_cdf(self, value, *args, **kwargs): | |||
| """ | |||
| Evaluate the log cdf at given value. | |||
| Args: | |||
| value (Tensor): value to be evaluated. | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| The argument `args` must include `value`. | |||
| dist_spec_args are optional. | |||
| A distribution can be optionally passed to the function by passing its dist_spec_args through | |||
| `args` or `kwargs`. | |||
| """ | |||
| return self._call_log_cdf(*args, **kwargs) | |||
| return self._call_log_cdf(value, *args, **kwargs) | |||
| def _calc_log_cdf_from_call_cdf(self, *args, **kwargs): | |||
| def _calc_log_cdf_from_call_cdf(self, value, *args, **kwargs): | |||
| r""" | |||
| Evaluate log cdf from cdf. | |||
| .. math:: | |||
| log_cdf(x) = \log(cdf(x)) | |||
| """ | |||
| return self.log(self._call_cdf(*args, **kwargs)) | |||
| return self.log(self._call_cdf(value, *args, **kwargs)) | |||
| def survival_function(self, *args, **kwargs): | |||
| def survival_function(self, value, *args, **kwargs): | |||
| """ | |||
| Evaluate the survival function at given value. | |||
| Args: | |||
| value (Tensor): value to be evaluated. | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| The argument `args` must include `value`. | |||
| dist_spec_args are optional. | |||
| A distribution can be optionally passed to the function by passing its dist_spec_args through | |||
| `args` or `kwargs`. | |||
| """ | |||
| return self._call_survival(*args, **kwargs) | |||
| return self._call_survival(value, *args, **kwargs) | |||
| def _calc_survival_from_call_cdf(self, *args, **kwargs): | |||
| def _calc_survival_from_call_cdf(self, value, *args, **kwargs): | |||
| r""" | |||
| Evaluate survival function from cdf. | |||
| .. math:: | |||
| survival_function(x) = 1 - (cdf(x)) | |||
| """ | |||
| return 1.0 - self._call_cdf(*args, **kwargs) | |||
| return 1.0 - self._call_cdf(value, *args, **kwargs) | |||
| def _calc_survival_from_log_survival(self, *args, **kwargs): | |||
| def _calc_survival_from_log_survival(self, value, *args, **kwargs): | |||
| r""" | |||
| Evaluate survival function from log survival function. | |||
| .. math:: | |||
| survival_function(x) = \exp(log_survival(x)) | |||
| """ | |||
| return self.exp(self._log_survival(*args, **kwargs)) | |||
| return self.exp(self._log_survival(value, *args, **kwargs)) | |||
| def log_survival(self, *args, **kwargs): | |||
| def log_survival(self, value, *args, **kwargs): | |||
| """ | |||
| Evaluate the log survival function at given value. | |||
| Args: | |||
| value (Tensor): value to be evaluated. | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| The arguments `args` must include `value`. | |||
| dist_spec_args are optional. | |||
| A distribution can be optionally passed to the function by passing its dist_spec_args through | |||
| `args` or `kwargs`. | |||
| """ | |||
| return self._call_log_survival(*args, **kwargs) | |||
| return self._call_log_survival(value, *args, **kwargs) | |||
| def _calc_log_survival_from_call_survival(self, *args, **kwargs): | |||
| def _calc_log_survival_from_call_survival(self, value, *args, **kwargs): | |||
| r""" | |||
| Evaluate log survival function from survival function. | |||
| .. math:: | |||
| log_survival(x) = \log(survival_function(x)) | |||
| """ | |||
| return self.log(self._call_survival(*args, **kwargs)) | |||
| return self.log(self._call_survival(value, *args, **kwargs)) | |||
| def kl_loss(self, *args, **kwargs): | |||
| def kl_loss(self, dist, *args, **kwargs): | |||
| """ | |||
| Evaluate the KL divergence, i.e. KL(a||b). | |||
| Args: | |||
| dist (str): type of the distribution. | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| The argument `args` must include the type of the distribution, parameters of distribution b. | |||
| Parameters for distribution a are optional. | |||
| dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`. | |||
| Passing in dist_spec_args of distribution a is optional. | |||
| """ | |||
| return self._kl_loss(*args, **kwargs) | |||
| return self._kl_loss(dist, *args, **kwargs) | |||
| def mean(self, *args, **kwargs): | |||
| """ | |||
| Evaluate the mean. | |||
| Args: | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| dist_spec_args are optional. | |||
| A distribution can be optionally passed to the function by passing its *dist_spec_args* through | |||
| *args* or *kwargs*. | |||
| """ | |||
| return self._mean(*args, **kwargs) | |||
| @@ -375,8 +461,13 @@ class Distribution(Cell): | |||
| """ | |||
| Evaluate the mode. | |||
| Args: | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| dist_spec_args are optional. | |||
| A distribution can be optionally passed to the function by passing its *dist_spec_args* through | |||
| *args* or *kwargs*. | |||
| """ | |||
| return self._mode(*args, **kwargs) | |||
| @@ -384,8 +475,13 @@ class Distribution(Cell): | |||
| """ | |||
| Evaluate the standard deviation. | |||
| Args: | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| dist_spec_args are optional. | |||
| A distribution can be optionally passed to the function by passing its *dist_spec_args* through | |||
| *args* or *kwargs*. | |||
| """ | |||
| return self._call_sd(*args, **kwargs) | |||
| @@ -393,8 +489,13 @@ class Distribution(Cell): | |||
| """ | |||
| Evaluate the variance. | |||
| Args: | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| dist_spec_args are optional. | |||
| A distribution can be optionally passed to the function by passing its *dist_spec_args* through | |||
| *args* or *kwargs*. | |||
| """ | |||
| return self._call_var(*args, **kwargs) | |||
| @@ -420,37 +521,52 @@ class Distribution(Cell): | |||
| """ | |||
| Evaluate the entropy. | |||
| Args: | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| dist_spec_args are optional. | |||
| A distribution can be optionally passed to the function by passing its *dist_spec_args* through | |||
| *args* or *kwargs*. | |||
| """ | |||
| return self._entropy(*args, **kwargs) | |||
| def cross_entropy(self, *args, **kwargs): | |||
| def cross_entropy(self, dist, *args, **kwargs): | |||
| """ | |||
| Evaluate the cross_entropy between distribution a and b. | |||
| Args: | |||
| dist (str): type of the distribution. | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| The argument `args` must include the type of the distribution, parameters of distribution b. | |||
| Parameters for distribution a are optional. | |||
| dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`. | |||
| Passing in dist_spec_args of distribution a is optional. | |||
| """ | |||
| return self._call_cross_entropy(*args, **kwargs) | |||
| return self._call_cross_entropy(dist, *args, **kwargs) | |||
| def _calc_cross_entropy(self, *args, **kwargs): | |||
| def _calc_cross_entropy(self, dist, *args, **kwargs): | |||
| r""" | |||
| Evaluate cross_entropy from entropy and kl divergence. | |||
| .. math:: | |||
| H(X, Y) = H(X) + KL(X||Y) | |||
| """ | |||
| return self._entropy(*args, **kwargs) + self._kl_loss(*args, **kwargs) | |||
| return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs) | |||
| def sample(self, *args, **kwargs): | |||
| """ | |||
| Sampling function. | |||
| Args: | |||
| shape (tuple): shape of the sample. | |||
| *args (list): the list of positional arguments forwarded to subclasses. | |||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||
| Note: | |||
| The shape of the sample defaults to (). | |||
| dist_spec_args are optional. | |||
| A distribution can be optionally passed to the function by passing its *dist_spec_args* through | |||
| *args* or *kwargs*. | |||
| """ | |||
| return self._sample(*args, **kwargs) | |||
| @@ -18,8 +18,7 @@ from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ | |||
| raise_none_error | |||
| from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type | |||
| from ._utils.custom_ops import exp_generic, log_generic | |||
| class Exponential(Distribution): | |||
| @@ -121,15 +120,19 @@ class Exponential(Distribution): | |||
| valid_dtype = mstype.float_type | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Exponential, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = dtype | |||
| self.parameter_type = set_param_type({'rate': rate}, self.dtype) | |||
| if rate is not None: | |||
| self._rate = cast_to_tensor(rate, self.parameter_type) | |||
| check_greater_zero(self._rate, "rate") | |||
| else: | |||
| self._rate = rate | |||
| self.default_parameters = [self.rate] | |||
| self.parameter_names = ['rate'] | |||
| self.minval = np.finfo(np.float).tiny | |||
| # ops needed for the class | |||
| self.exp = exp_generic | |||
| self.log = log_generic | |||
| @@ -156,28 +159,16 @@ class Exponential(Distribution): | |||
| @property | |||
| def rate(self): | |||
| """ | |||
| Return rate of the distribution. | |||
| Return `rate` of the distribution. | |||
| """ | |||
| return self._rate | |||
| def _check_param(self, rate): | |||
| """ | |||
| Check availablity of distribution specific argument `rate`. | |||
| """ | |||
| if rate is not None: | |||
| if self.context_mode == 0: | |||
| self.checktensor(rate, 'rate') | |||
| else: | |||
| rate = self.checktensor(rate, 'rate') | |||
| return self.cast(rate, self.parameter_type) | |||
| return self.rate if self.rate is not None else raise_none_error('rate') | |||
| def _mean(self, rate=None): | |||
| r""" | |||
| .. math:: | |||
| MEAN(EXP) = \frac{1.0}{\lambda}. | |||
| """ | |||
| rate = self._check_param(rate) | |||
| rate = self._check_param_type(rate) | |||
| return 1.0 / rate | |||
| def _mode(self, rate=None): | |||
| @@ -185,7 +176,7 @@ class Exponential(Distribution): | |||
| .. math:: | |||
| MODE(EXP) = 0. | |||
| """ | |||
| rate = self._check_param(rate) | |||
| rate = self._check_param_type(rate) | |||
| return self.fill(self.dtype, self.shape(rate), 0.) | |||
| def _sd(self, rate=None): | |||
| @@ -193,7 +184,7 @@ class Exponential(Distribution): | |||
| .. math:: | |||
| sd(EXP) = \frac{1.0}{\lambda}. | |||
| """ | |||
| rate = self._check_param(rate) | |||
| rate = self._check_param_type(rate) | |||
| return 1.0 / rate | |||
| def _entropy(self, rate=None): | |||
| @@ -201,7 +192,7 @@ class Exponential(Distribution): | |||
| .. math:: | |||
| H(Exp) = 1 - \log(\lambda). | |||
| """ | |||
| rate = self._check_param(rate) | |||
| rate = self._check_param_type(rate) | |||
| return 1.0 - self.log(rate) | |||
| def _cross_entropy(self, dist, rate_b, rate=None): | |||
| @@ -234,7 +225,7 @@ class Exponential(Distribution): | |||
| """ | |||
| value = self._check_value(value, "value") | |||
| value = self.cast(value, self.dtype) | |||
| rate = self._check_param(rate) | |||
| rate = self._check_param_type(rate) | |||
| prob = self.log(rate) - rate * value | |||
| zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) | |||
| neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf) | |||
| @@ -257,7 +248,7 @@ class Exponential(Distribution): | |||
| """ | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, self.dtype) | |||
| rate = self._check_param(rate) | |||
| rate = self._check_param_type(rate) | |||
| cdf = 1.0 - self.exp(-1. * rate * value) | |||
| zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) | |||
| comp = self.less(value, zeros) | |||
| @@ -279,7 +270,7 @@ class Exponential(Distribution): | |||
| """ | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, self.dtype) | |||
| rate = self._check_param(rate) | |||
| rate = self._check_param_type(rate) | |||
| sf = -1. * rate * value | |||
| zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0) | |||
| comp = self.less(value, zeros) | |||
| @@ -297,7 +288,7 @@ class Exponential(Distribution): | |||
| check_distribution_name(dist, 'Exponential') | |||
| rate_b = self._check_value(rate_b, 'rate_b') | |||
| rate_b = self.cast(rate_b, self.parameter_type) | |||
| rate_a = self._check_param(rate) | |||
| rate_a = self._check_param_type(rate) | |||
| return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 | |||
| def _sample(self, shape=(), rate=None): | |||
| @@ -312,7 +303,7 @@ class Exponential(Distribution): | |||
| Tensor, shape is shape + batch_shape. | |||
| """ | |||
| shape = self.checktuple(shape, 'shape') | |||
| rate = self._check_param(rate) | |||
| rate = self._check_param_type(rate) | |||
| origin_shape = shape + self.shape(rate) | |||
| if origin_shape == (): | |||
| sample_shape = (1,) | |||
| @@ -19,7 +19,7 @@ from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ | |||
| raise_none_error | |||
| set_param_type | |||
| from ._utils.custom_ops import exp_generic, log_generic | |||
| @@ -123,13 +123,16 @@ class Geometric(Distribution): | |||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Geometric, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = mstype.float32 | |||
| self.parameter_type = set_param_type({'probs1': probs}, mstype.float32) | |||
| if probs is not None: | |||
| self._probs = cast_to_tensor(probs, self.parameter_type) | |||
| check_prob(self._probs) | |||
| else: | |||
| self._probs = probs | |||
| self.default_parameters = [self.probs] | |||
| self.parameter_names = ['probs1'] | |||
| self.minval = np.finfo(np.float).tiny | |||
| # ops needed for the class | |||
| @@ -164,24 +167,12 @@ class Geometric(Distribution): | |||
| """ | |||
| return self._probs | |||
| def _check_param(self, probs1): | |||
| """ | |||
| Check availablity of distribution specific args probs1. | |||
| """ | |||
| if probs1 is not None: | |||
| if self.context_mode == 0: | |||
| self.checktensor(probs1, 'probs1') | |||
| else: | |||
| probs1 = self.checktensor(probs1, 'probs1') | |||
| return self.cast(probs1, self.parameter_type) | |||
| return self.probs if self.probs is not None else raise_none_error('probs1') | |||
| def _mean(self, probs1=None): | |||
| r""" | |||
| .. math:: | |||
| MEAN(Geo) = \frac{1 - probs1}{probs1} | |||
| """ | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| return (1. - probs1) / probs1 | |||
| def _mode(self, probs1=None): | |||
| @@ -189,7 +180,7 @@ class Geometric(Distribution): | |||
| .. math:: | |||
| MODE(Geo) = 0 | |||
| """ | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.) | |||
| def _var(self, probs1=None): | |||
| @@ -197,7 +188,7 @@ class Geometric(Distribution): | |||
| .. math:: | |||
| VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}} | |||
| """ | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| return (1.0 - probs1) / self.sq(probs1) | |||
| def _entropy(self, probs1=None): | |||
| @@ -205,7 +196,7 @@ class Geometric(Distribution): | |||
| .. math:: | |||
| H(Geo) = \frac{-probs0 \log(probs0) - probs1 \log(probs1)}{probs1} | |||
| """ | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| probs0 = 1.0 - probs1 | |||
| return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1 | |||
| @@ -236,7 +227,7 @@ class Geometric(Distribution): | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, mstype.float32) | |||
| value = self.floor(value) | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) | |||
| zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) | |||
| comp = self.less(value, zeros) | |||
| @@ -258,7 +249,7 @@ class Geometric(Distribution): | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, mstype.float32) | |||
| value = self.floor(value) | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| probs0 = 1.0 - probs1 | |||
| cdf = 1.0 - self.pow(probs0, value + 1.0) | |||
| zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0) | |||
| @@ -280,7 +271,7 @@ class Geometric(Distribution): | |||
| check_distribution_name(dist, 'Geometric') | |||
| probs1_b = self._check_value(probs1_b, 'probs1_b') | |||
| probs1_b = self.cast(probs1_b, self.parameter_type) | |||
| probs1_a = self._check_param(probs1) | |||
| probs1_a = self._check_param_type(probs1) | |||
| probs0_a = 1.0 - probs1_a | |||
| probs0_b = 1.0 - probs1_b | |||
| return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b) | |||
| @@ -297,7 +288,7 @@ class Geometric(Distribution): | |||
| Tensor, shape is shape + batch_shape. | |||
| """ | |||
| shape = self.checktuple(shape, 'shape') | |||
| probs1 = self._check_param(probs1) | |||
| probs1 = self._check_param_type(probs1) | |||
| origin_shape = shape + self.shape(probs1) | |||
| if origin_shape == (): | |||
| sample_shape = (1,) | |||
| @@ -19,7 +19,7 @@ from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ | |||
| raise_none_error, common_dtype | |||
| set_param_type | |||
| from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic | |||
| class Normal(Distribution): | |||
| @@ -127,14 +127,17 @@ class Normal(Distribution): | |||
| valid_dtype = mstype.float_type | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Normal, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = common_dtype(mean, 'mean', sd, 'sd', self.dtype) | |||
| self.parameter_type = set_param_type({'mean': mean, 'sd': sd}, self.dtype) | |||
| if mean is not None and sd is not None: | |||
| self._mean_value = cast_to_tensor(mean, self.parameter_type) | |||
| self._sd_value = cast_to_tensor(sd, self.parameter_type) | |||
| check_greater_zero(self._sd_value, "Standard deviation") | |||
| else: | |||
| self._mean_value = mean | |||
| self._sd_value = sd | |||
| self._mean_value = mean if mean is None else cast_to_tensor(mean, self.parameter_type) | |||
| self._sd_value = sd if sd is None else cast_to_tensor(sd, self.parameter_type) | |||
| self.default_parameters = [self._mean_value, self._sd_value] | |||
| self.parameter_names = ['mean', 'sd'] | |||
| #ops needed for the class | |||
| self.exp = exp_generic | |||
| @@ -159,51 +162,25 @@ class Normal(Distribution): | |||
| str_info = f'batch_shape = {self._broadcast_shape}' | |||
| return str_info | |||
| def _check_param(self, mean, sd): | |||
| """ | |||
| Check availability of distribution specific args `mean` and `sd`. | |||
| """ | |||
| if mean is not None: | |||
| if self.context_mode == 0: | |||
| self.checktensor(mean, 'mean') | |||
| else: | |||
| mean = self.checktensor(mean, 'mean') | |||
| else: | |||
| mean = self._mean_value if self._mean_value is not None else raise_none_error('mean') | |||
| if sd is not None: | |||
| if self.context_mode == 0: | |||
| self.checktensor(sd, 'sd') | |||
| else: | |||
| sd = self.checktensor(sd, 'sd') | |||
| else: | |||
| sd = self._sd_value if self._sd_value is not None else raise_none_error('sd') | |||
| batch_shape = self.shape(mean + sd) | |||
| mean = mean * self.fill(self.dtypeop(mean), batch_shape, 1.0) | |||
| sd = sd * self.fill(self.dtypeop(sd), batch_shape, 1.0) | |||
| self.sametypeshape(mean, sd) | |||
| mean = self.cast(mean, self.parameter_type) | |||
| sd = self.cast(sd, self.parameter_type) | |||
| return mean, sd | |||
| def _mean(self, mean=None, sd=None): | |||
| """ | |||
| The mean of the distribution. | |||
| """ | |||
| mean, sd = self._check_param(mean, sd) | |||
| mean, sd = self._check_param_type(mean, sd) | |||
| return mean | |||
| def _mode(self, mean=None, sd=None): | |||
| """ | |||
| The mode of the distribution. | |||
| """ | |||
| mean, sd = self._check_param(mean, sd) | |||
| mean, sd = self._check_param_type(mean, sd) | |||
| return mean | |||
| def _sd(self, mean=None, sd=None): | |||
| """ | |||
| The standard deviation of the distribution. | |||
| """ | |||
| mean, sd = self._check_param(mean, sd) | |||
| mean, sd = self._check_param_type(mean, sd) | |||
| return sd | |||
| def _entropy(self, mean=None, sd=None): | |||
| @@ -213,7 +190,7 @@ class Normal(Distribution): | |||
| .. math:: | |||
| H(X) = \log(\sqrt{2 \pi e \sigma^2}) | |||
| """ | |||
| mean, sd = self._check_param(mean, sd) | |||
| mean, sd = self._check_param_type(mean, sd) | |||
| return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd) | |||
| def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None): | |||
| @@ -244,7 +221,7 @@ class Normal(Distribution): | |||
| """ | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, self.dtype) | |||
| mean, sd = self._check_param(mean, sd) | |||
| mean, sd = self._check_param_type(mean, sd) | |||
| unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) | |||
| neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd) | |||
| return unnormalized_log_prob + neg_normalization | |||
| @@ -263,7 +240,7 @@ class Normal(Distribution): | |||
| """ | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, self.dtype) | |||
| mean, sd = self._check_param(mean, sd) | |||
| mean, sd = self._check_param_type(mean, sd) | |||
| sqrt2 = self.sqrt(self.const(2.0)) | |||
| adjusted = (value - mean) / (sd * sqrt2) | |||
| return 0.5 * (1.0 + self.erf(adjusted)) | |||
| @@ -288,7 +265,7 @@ class Normal(Distribution): | |||
| sd_b = self._check_value(sd_b, 'sd_b') | |||
| mean_b = self.cast(mean_b, self.parameter_type) | |||
| sd_b = self.cast(sd_b, self.parameter_type) | |||
| mean_a, sd_a = self._check_param(mean, sd) | |||
| mean_a, sd_a = self._check_param_type(mean, sd) | |||
| diff_log_scale = self.log(sd_a) - self.log(sd_b) | |||
| squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b) | |||
| return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale | |||
| @@ -306,7 +283,7 @@ class Normal(Distribution): | |||
| Tensor, shape is shape + batch_shape. | |||
| """ | |||
| shape = self.checktuple(shape, 'shape') | |||
| mean, sd = self._check_param(mean, sd) | |||
| mean, sd = self._check_param_type(mean, sd) | |||
| batch_shape = self.shape(mean + sd) | |||
| origin_shape = shape + batch_shape | |||
| if origin_shape == (): | |||
| @@ -18,7 +18,7 @@ from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\ | |||
| raise_none_error, common_dtype | |||
| set_param_type | |||
| from ._utils.custom_ops import exp_generic, log_generic | |||
| class Uniform(Distribution): | |||
| @@ -126,14 +126,17 @@ class Uniform(Distribution): | |||
| valid_dtype = mstype.float_type | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Uniform, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = common_dtype(low, 'low', high, 'high', self.dtype) | |||
| self.parameter_type = set_param_type({'low': low, 'high': high}, self.dtype) | |||
| if low is not None and high is not None: | |||
| self._low = cast_to_tensor(low, dtype) | |||
| self._high = cast_to_tensor(high, dtype) | |||
| self._low = cast_to_tensor(low, self.parameter_type) | |||
| self._high = cast_to_tensor(high, self.parameter_type) | |||
| check_greater(self.low, self.high, "low value", "high value") | |||
| else: | |||
| self._low = low | |||
| self._high = high | |||
| self._low = low if low is None else cast_to_tensor(low, self.parameter_type) | |||
| self._high = high if high is None else cast_to_tensor(high, self.parameter_type) | |||
| self.default_parameters = [self.low, self.high] | |||
| self.parameter_names = ['low', 'high'] | |||
| # ops needed for the class | |||
| self.exp = exp_generic | |||
| @@ -162,32 +165,6 @@ class Uniform(Distribution): | |||
| str_info = f'batch_shape = {self._broadcast_shape}' | |||
| return str_info | |||
| def _check_param(self, low, high): | |||
| """ | |||
| Check availability of distribution specific args `low` and `high`. | |||
| """ | |||
| if low is not None: | |||
| if self.context_mode == 0: | |||
| self.checktensor(low, 'low') | |||
| else: | |||
| low = self.checktensor(low, 'low') | |||
| else: | |||
| low = self.low if self.low is not None else raise_none_error('low') | |||
| if high is not None: | |||
| if self.context_mode == 0: | |||
| self.checktensor(high, 'high') | |||
| else: | |||
| high = self.checktensor(high, 'high') | |||
| else: | |||
| high = self.high if self.high is not None else raise_none_error('high') | |||
| batch_shape = self.shape(high - low) | |||
| high = high * self.fill(self.dtypeop(high), batch_shape, 1.0) | |||
| low = low * self.fill(self.dtypeop(low), batch_shape, 1.0) | |||
| self.sametypeshape(high, low) | |||
| low = self.cast(low, self.parameter_type) | |||
| high = self.cast(high, self.parameter_type) | |||
| return low, high | |||
| @property | |||
| def low(self): | |||
| """ | |||
| @@ -209,7 +186,7 @@ class Uniform(Distribution): | |||
| .. math:: | |||
| range(U) = high - low | |||
| """ | |||
| low, high = self._check_param(low, high) | |||
| low, high = self._check_param_type(low, high) | |||
| return high - low | |||
| def _mean(self, low=None, high=None): | |||
| @@ -217,7 +194,7 @@ class Uniform(Distribution): | |||
| .. math:: | |||
| MEAN(U) = \frac{low + high}{2}. | |||
| """ | |||
| low, high = self._check_param(low, high) | |||
| low, high = self._check_param_type(low, high) | |||
| return (low + high) / 2. | |||
| def _var(self, low=None, high=None): | |||
| @@ -225,7 +202,7 @@ class Uniform(Distribution): | |||
| .. math:: | |||
| VAR(U) = \frac{(high - low) ^ 2}{12}. | |||
| """ | |||
| low, high = self._check_param(low, high) | |||
| low, high = self._check_param_type(low, high) | |||
| return self.sq(high - low) / 12.0 | |||
| def _entropy(self, low=None, high=None): | |||
| @@ -233,7 +210,7 @@ class Uniform(Distribution): | |||
| .. math:: | |||
| H(U) = \log(high - low). | |||
| """ | |||
| low, high = self._check_param(low, high) | |||
| low, high = self._check_param_type(low, high) | |||
| return self.log(high - low) | |||
| def _cross_entropy(self, dist, low_b, high_b, low=None, high=None): | |||
| @@ -266,7 +243,7 @@ class Uniform(Distribution): | |||
| """ | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, self.dtype) | |||
| low, high = self._check_param(low, high) | |||
| low, high = self._check_param_type(low, high) | |||
| neg_ones = self.fill(self.dtype, self.shape(value), -1.0) | |||
| prob = self.exp(neg_ones * self.log(high - low)) | |||
| broadcast_shape = self.shape(prob) | |||
| @@ -292,7 +269,7 @@ class Uniform(Distribution): | |||
| low_b = self.cast(low_b, self.parameter_type) | |||
| high_b = self._check_value(high_b, 'high_b') | |||
| high_b = self.cast(high_b, self.parameter_type) | |||
| low_a, high_a = self._check_param(low, high) | |||
| low_a, high_a = self._check_param_type(low, high) | |||
| kl = self.log(high_b - low_b) - self.log(high_a - low_a) | |||
| comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b)) | |||
| return self.select(comp, kl, self.log(self.zeroslike(kl))) | |||
| @@ -313,7 +290,7 @@ class Uniform(Distribution): | |||
| """ | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, self.dtype) | |||
| low, high = self._check_param(low, high) | |||
| low, high = self._check_param_type(low, high) | |||
| prob = (value - low) / (high - low) | |||
| broadcast_shape = self.shape(prob) | |||
| zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) | |||
| @@ -336,7 +313,7 @@ class Uniform(Distribution): | |||
| Tensor, shape is shape + batch_shape. | |||
| """ | |||
| shape = self.checktuple(shape, 'shape') | |||
| low, high = self._check_param(low, high) | |||
| low, high = self._check_param_type(low, high) | |||
| broadcast_shape = self.shape(low + high) | |||
| origin_shape = shape + broadcast_shape | |||
| if origin_shape == (): | |||
| @@ -0,0 +1,182 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Test util functions used in distribution classes. | |||
| """ | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore import context | |||
| from mindspore import dtype | |||
| from mindspore import Tensor | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.nn.probability.distribution._utils.utils import set_param_type, \ | |||
| cast_to_tensor, CheckTuple, CheckTensor | |||
| def test_set_param_type(): | |||
| """ | |||
| Test set_param_type function. | |||
| """ | |||
| tensor_fp16 = Tensor(0.1, dtype=dtype.float16) | |||
| tensor_fp32 = Tensor(0.1, dtype=dtype.float32) | |||
| tensor_fp64 = Tensor(0.1, dtype=dtype.float64) | |||
| tensor_int32 = Tensor(0.1, dtype=dtype.int32) | |||
| array_fp32 = np.array(1.0).astype(np.float32) | |||
| array_fp64 = np.array(1.0).astype(np.float64) | |||
| array_int32 = np.array(1.0).astype(np.int32) | |||
| dict1 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp32} | |||
| dict2 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp64} | |||
| dict3 = {'a': tensor_int32, 'b': 1.0, 'c': tensor_int32} | |||
| dict4 = {'a': array_fp32, 'b': 1.0, 'c': tensor_fp32} | |||
| dict5 = {'a': array_fp32, 'b': 1.0, 'c': array_fp64} | |||
| dict6 = {'a': array_fp32, 'b': 1.0, 'c': array_int32} | |||
| dict7 = {'a': 1.0} | |||
| dict8 = {'a': 1.0, 'b': 1.0, 'c': 1.0} | |||
| dict9 = {'a': tensor_fp16, 'b': tensor_fp16, 'c': tensor_fp16} | |||
| dict10 = {'a': tensor_fp64, 'b': tensor_fp64, 'c': tensor_fp64} | |||
| dict11 = {'a': array_fp64, 'b': array_fp64, 'c': tensor_fp64} | |||
| ans1 = set_param_type(dict1, dtype.float16) | |||
| assert ans1 == dtype.float32 | |||
| with pytest.raises(TypeError): | |||
| set_param_type(dict2, dtype.float32) | |||
| ans3 = set_param_type(dict3, dtype.float16) | |||
| assert ans3 == dtype.float32 | |||
| ans4 = set_param_type(dict4, dtype.float16) | |||
| assert ans4 == dtype.float32 | |||
| with pytest.raises(TypeError): | |||
| set_param_type(dict5, dtype.float32) | |||
| with pytest.raises(TypeError): | |||
| set_param_type(dict6, dtype.float32) | |||
| ans7 = set_param_type(dict7, dtype.float32) | |||
| assert ans7 == dtype.float32 | |||
| ans8 = set_param_type(dict8, dtype.float32) | |||
| assert ans8 == dtype.float32 | |||
| ans9 = set_param_type(dict9, dtype.float32) | |||
| assert ans9 == dtype.float16 | |||
| ans10 = set_param_type(dict10, dtype.float32) | |||
| assert ans10 == dtype.float32 | |||
| ans11 = set_param_type(dict11, dtype.float32) | |||
| assert ans11 == dtype.float32 | |||
| def test_cast_to_tensor(): | |||
| """ | |||
| Test cast_to_tensor. | |||
| """ | |||
| with pytest.raises(ValueError): | |||
| cast_to_tensor(None, dtype.float32) | |||
| with pytest.raises(TypeError): | |||
| cast_to_tensor(True, dtype.float32) | |||
| with pytest.raises(TypeError): | |||
| cast_to_tensor({'a': 1, 'b': 2}, dtype.float32) | |||
| with pytest.raises(TypeError): | |||
| cast_to_tensor('tensor', dtype.float32) | |||
| ans1 = cast_to_tensor(Parameter(Tensor(0.1, dtype=dtype.float32), 'param')) | |||
| assert isinstance(ans1, Parameter) | |||
| ans2 = cast_to_tensor(np.array(1.0).astype(np.float32)) | |||
| assert isinstance(ans2, Tensor) | |||
| ans3 = cast_to_tensor([1.0, 2.0]) | |||
| assert isinstance(ans3, Tensor) | |||
| ans4 = cast_to_tensor(Tensor(0.1, dtype=dtype.float32), dtype.float32) | |||
| assert isinstance(ans4, Tensor) | |||
| ans5 = cast_to_tensor(0.1, dtype.float32) | |||
| assert isinstance(ans5, Tensor) | |||
| ans6 = cast_to_tensor(1, dtype.float32) | |||
| assert isinstance(ans6, Tensor) | |||
| class Net(Cell): | |||
| """ | |||
| Test class: CheckTuple. | |||
| """ | |||
| def __init__(self, value): | |||
| super(Net, self).__init__() | |||
| self.checktuple = CheckTuple() | |||
| self.value = value | |||
| def construct(self, value=None): | |||
| if value is None: | |||
| return self.checktuple(self.value, 'input') | |||
| return self.checktuple(value, 'input') | |||
| def test_check_tuple(): | |||
| """ | |||
| Test CheckTuple. | |||
| """ | |||
| net1 = Net((1, 2, 3)) | |||
| ans1 = net1() | |||
| assert isinstance(ans1, tuple) | |||
| with pytest.raises(TypeError): | |||
| net2 = Net('tuple') | |||
| net2() | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| net3 = Net((1, 2, 3)) | |||
| ans3 = net3() | |||
| assert isinstance(ans3, tuple) | |||
| with pytest.raises(TypeError): | |||
| net4 = Net('tuple') | |||
| net4() | |||
| class Net1(Cell): | |||
| """ | |||
| Test class: CheckTensor. | |||
| """ | |||
| def __init__(self, value): | |||
| super(Net1, self).__init__() | |||
| self.checktensor = CheckTensor() | |||
| self.value = value | |||
| self.context = context.get_context('mode') | |||
| def construct(self, value=None): | |||
| value = self.value if value is None else value | |||
| if self.context == 0: | |||
| self.checktensor(value, 'input') | |||
| return value | |||
| return self.checktensor(value, 'input') | |||
| def test_check_tensor(): | |||
| """ | |||
| Test CheckTensor. | |||
| """ | |||
| value = Tensor(0.1, dtype=dtype.float32) | |||
| net1 = Net1(value) | |||
| ans1 = net1() | |||
| assert isinstance(ans1, Tensor) | |||
| ans1 = net1(value) | |||
| assert isinstance(ans1, Tensor) | |||
| with pytest.raises(TypeError): | |||
| net2 = Net1('tuple') | |||
| net2() | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| net3 = Net1(value) | |||
| ans3 = net3() | |||
| assert isinstance(ans3, Tensor) | |||
| ans3 = net3(value) | |||
| assert isinstance(ans3, Tensor) | |||
| with pytest.raises(TypeError): | |||
| net4 = Net1('tuple') | |||
| net4() | |||