| @@ -110,6 +110,8 @@ def check_scalar_from_param(params): | |||||
| Notes: String parameters are excluded. | Notes: String parameters are excluded. | ||||
| """ | """ | ||||
| for value in params.values(): | for value in params.values(): | ||||
| if value is None: | |||||
| continue | |||||
| if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)): | if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)): | ||||
| return params['distribution'].is_scalar_batch | return params['distribution'].is_scalar_batch | ||||
| if isinstance(value, Parameter): | if isinstance(value, Parameter): | ||||
| @@ -358,23 +360,29 @@ class CheckTensor(PrimitiveWithInfer): | |||||
| return x | return x | ||||
| raise TypeError(f"For {name}, input type should be a Tensor or Parameter.") | raise TypeError(f"For {name}, input type should be a Tensor or Parameter.") | ||||
| def common_dtype(arg_a, name_a, arg_b, name_b, hint_type): | |||||
| def set_param_type(args, hint_type): | |||||
| """ | """ | ||||
| check if arg_a and arg_b have the same dtype. | |||||
| Find the common type among arguments. | |||||
| Args: | |||||
| args (dict): dictionary of arguments, {'name':value}. | |||||
| hint_type (mindspore.dtype): hint type to return. | |||||
| Raises: | |||||
| TypeError: if tensors in args are not the same dtype. | |||||
| """ | """ | ||||
| if hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype'): | |||||
| if isinstance(arg_a, np.ndarray): | |||||
| a_dtype = mstype.pytype_to_dtype(arg_a.dtype) | |||||
| else: | |||||
| a_dtype = arg_a.dtype | |||||
| if isinstance(arg_b, np.ndarray): | |||||
| b_dtype = mstype.pytype_to_dtype(arg_b.dtype) | |||||
| else: | |||||
| b_dtype = arg_b.dtype | |||||
| if a_dtype != b_dtype: | |||||
| raise TypeError(f"{name_a} and {name_b} should have the same dtype.") | |||||
| int_type = mstype.int_type + mstype.uint_type | |||||
| if a_dtype in int_type or a_dtype == mstype.float64: | |||||
| return mstype.float32 | |||||
| return a_dtype | |||||
| return hint_type | |||||
| common_dtype = None | |||||
| for name, arg in args.items(): | |||||
| if hasattr(arg, 'dtype'): | |||||
| if isinstance(arg, np.ndarray): | |||||
| cur_dtype = mstype.pytype_to_dtype(arg.dtype) | |||||
| else: | |||||
| cur_dtype = arg.dtype | |||||
| if common_dtype is None: | |||||
| common_dtype = cur_dtype | |||||
| elif cur_dtype != common_dtype: | |||||
| raise TypeError(f"{name} should have the same dtype as other arguments.") | |||||
| int_type = mstype.int_type + mstype.uint_type | |||||
| if common_dtype in int_type or common_dtype == mstype.float64: | |||||
| return mstype.float32 | |||||
| return hint_type if common_dtype is None else common_dtype | |||||
| @@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error | |||||
| from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type | |||||
| from ._utils.custom_ops import exp_generic, log_generic, erf_generic | from ._utils.custom_ops import exp_generic, log_generic, erf_generic | ||||
| @@ -119,13 +119,16 @@ class Bernoulli(Distribution): | |||||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | ||||
| check_type(dtype, valid_dtype, type(self).__name__) | check_type(dtype, valid_dtype, type(self).__name__) | ||||
| super(Bernoulli, self).__init__(seed, dtype, name, param) | super(Bernoulli, self).__init__(seed, dtype, name, param) | ||||
| self.parameter_type = mstype.float32 | |||||
| self.parameter_type = set_param_type({'probs1': probs}, mstype.float32) | |||||
| if probs is not None: | if probs is not None: | ||||
| self._probs = cast_to_tensor(probs, mstype.float32) | |||||
| self._probs = cast_to_tensor(probs, self.parameter_type) | |||||
| check_prob(self.probs) | check_prob(self.probs) | ||||
| else: | else: | ||||
| self._probs = probs | self._probs = probs | ||||
| self.default_parameters = [self.probs] | |||||
| self.parameter_names = ['probs1'] | |||||
| # ops needed for the class | # ops needed for the class | ||||
| self.exp = exp_generic | self.exp = exp_generic | ||||
| self.log = log_generic | self.log = log_generic | ||||
| @@ -157,24 +160,12 @@ class Bernoulli(Distribution): | |||||
| """ | """ | ||||
| return self._probs | return self._probs | ||||
| def _check_param(self, probs1): | |||||
| """ | |||||
| Check availablity of distribution specific args `probs1`. | |||||
| """ | |||||
| if probs1 is not None: | |||||
| if self.context_mode == 0: | |||||
| self.checktensor(probs1, 'probs1') | |||||
| else: | |||||
| probs1 = self.checktensor(probs1, 'probs1') | |||||
| return self.cast(probs1, self.parameter_type) | |||||
| return self.probs if self.probs is not None else raise_none_error('probs1') | |||||
| def _mean(self, probs1=None): | def _mean(self, probs1=None): | ||||
| r""" | r""" | ||||
| .. math:: | .. math:: | ||||
| MEAN(B) = probs1 | MEAN(B) = probs1 | ||||
| """ | """ | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| return probs1 | return probs1 | ||||
| def _mode(self, probs1=None): | def _mode(self, probs1=None): | ||||
| @@ -182,7 +173,7 @@ class Bernoulli(Distribution): | |||||
| .. math:: | .. math:: | ||||
| MODE(B) = 1 if probs1 > 0.5 else = 0 | MODE(B) = 1 if probs1 > 0.5 else = 0 | ||||
| """ | """ | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| prob_type = self.dtypeop(probs1) | prob_type = self.dtypeop(probs1) | ||||
| zeros = self.fill(prob_type, self.shape(probs1), 0.0) | zeros = self.fill(prob_type, self.shape(probs1), 0.0) | ||||
| ones = self.fill(prob_type, self.shape(probs1), 1.0) | ones = self.fill(prob_type, self.shape(probs1), 1.0) | ||||
| @@ -194,7 +185,7 @@ class Bernoulli(Distribution): | |||||
| .. math:: | .. math:: | ||||
| VAR(B) = probs1 * probs0 | VAR(B) = probs1 * probs0 | ||||
| """ | """ | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| probs0 = 1.0 - probs1 | probs0 = 1.0 - probs1 | ||||
| return self.exp(self.log(probs0) + self.log(probs1)) | return self.exp(self.log(probs0) + self.log(probs1)) | ||||
| @@ -203,11 +194,11 @@ class Bernoulli(Distribution): | |||||
| .. math:: | .. math:: | ||||
| H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) | H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) | ||||
| """ | """ | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| probs0 = 1.0 - probs1 | probs0 = 1.0 - probs1 | ||||
| return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) | return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) | ||||
| def _cross_entropy(self, dist, probs1_b, probs1_a=None): | |||||
| def _cross_entropy(self, dist, probs1_b, probs1=None): | |||||
| """ | """ | ||||
| Evaluate cross_entropy between Bernoulli distributions. | Evaluate cross_entropy between Bernoulli distributions. | ||||
| @@ -217,7 +208,7 @@ class Bernoulli(Distribution): | |||||
| probs1_a (Tensor): `probs1` of distribution a. Default: self.probs. | probs1_a (Tensor): `probs1` of distribution a. Default: self.probs. | ||||
| """ | """ | ||||
| check_distribution_name(dist, 'Bernoulli') | check_distribution_name(dist, 'Bernoulli') | ||||
| return self._entropy(probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) | |||||
| return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1) | |||||
| def _log_prob(self, value, probs1=None): | def _log_prob(self, value, probs1=None): | ||||
| r""" | r""" | ||||
| @@ -233,7 +224,7 @@ class Bernoulli(Distribution): | |||||
| """ | """ | ||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, mstype.float32) | value = self.cast(value, mstype.float32) | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| probs0 = 1.0 - probs1 | probs0 = 1.0 - probs1 | ||||
| return self.log(probs1) * value + self.log(probs0) * (1.0 - value) | return self.log(probs1) * value + self.log(probs0) * (1.0 - value) | ||||
| @@ -253,7 +244,7 @@ class Bernoulli(Distribution): | |||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, mstype.float32) | value = self.cast(value, mstype.float32) | ||||
| value = self.floor(value) | value = self.floor(value) | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| prob_type = self.dtypeop(probs1) | prob_type = self.dtypeop(probs1) | ||||
| value = value * self.fill(prob_type, self.shape(probs1), 1.0) | value = value * self.fill(prob_type, self.shape(probs1), 1.0) | ||||
| probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0) | probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0) | ||||
| @@ -264,7 +255,7 @@ class Bernoulli(Distribution): | |||||
| less_than_zero = self.select(comp_zero, zeros, probs0) | less_than_zero = self.select(comp_zero, zeros, probs0) | ||||
| return self.select(comp_one, less_than_zero, ones) | return self.select(comp_one, less_than_zero, ones) | ||||
| def _kl_loss(self, dist, probs1_b, probs1_a=None): | |||||
| def _kl_loss(self, dist, probs1_b, probs1=None): | |||||
| r""" | r""" | ||||
| Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). | Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). | ||||
| @@ -280,7 +271,7 @@ class Bernoulli(Distribution): | |||||
| check_distribution_name(dist, 'Bernoulli') | check_distribution_name(dist, 'Bernoulli') | ||||
| probs1_b = self._check_value(probs1_b, 'probs1_b') | probs1_b = self._check_value(probs1_b, 'probs1_b') | ||||
| probs1_b = self.cast(probs1_b, self.parameter_type) | probs1_b = self.cast(probs1_b, self.parameter_type) | ||||
| probs1_a = self._check_param(probs1_a) | |||||
| probs1_a = self._check_param_type(probs1) | |||||
| probs0_a = 1.0 - probs1_a | probs0_a = 1.0 - probs1_a | ||||
| probs0_b = 1.0 - probs1_b | probs0_b = 1.0 - probs1_b | ||||
| return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b) | return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b) | ||||
| @@ -297,7 +288,7 @@ class Bernoulli(Distribution): | |||||
| Tensor, shape is shape + batch_shape. | Tensor, shape is shape + batch_shape. | ||||
| """ | """ | ||||
| shape = self.checktuple(shape, 'shape') | shape = self.checktuple(shape, 'shape') | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| origin_shape = shape + self.shape(probs1) | origin_shape = shape + self.shape(probs1) | ||||
| if origin_shape == (): | if origin_shape == (): | ||||
| sample_shape = (1,) | sample_shape = (1,) | ||||
| @@ -18,7 +18,8 @@ from mindspore.nn.cell import Cell | |||||
| from mindspore._checkparam import Validator as validator | from mindspore._checkparam import Validator as validator | ||||
| from mindspore._checkparam import Rel | from mindspore._checkparam import Rel | ||||
| from mindspore.common import get_seed | from mindspore.common import get_seed | ||||
| from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device | |||||
| from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device,\ | |||||
| raise_none_error | |||||
| from ._utils.utils import CheckTuple, CheckTensor | from ._utils.utils import CheckTuple, CheckTensor | ||||
| @@ -115,6 +116,51 @@ class Distribution(Cell): | |||||
| def broadcast_shape(self): | def broadcast_shape(self): | ||||
| return self._broadcast_shape | return self._broadcast_shape | ||||
| def _check_param_type(self, *args): | |||||
| """ | |||||
| Check the availability and validity of default parameters and dist_spec_args. | |||||
| dist_spec_args passed in must be tensors. If default parameter of the distribution | |||||
| is None, its parameter must be passed in through `args`. | |||||
| """ | |||||
| broadcast_shape = None | |||||
| common_dtype = None | |||||
| out = [] | |||||
| for arg, name, default in zip(args, self.parameter_names, self.default_parameters): | |||||
| # check if the argument is a Tensor | |||||
| if arg is not None: | |||||
| if self.context_mode == 0: | |||||
| self.checktensor(arg, name) | |||||
| else: | |||||
| arg = self.checktensor(arg, name) | |||||
| else: | |||||
| arg = default if default is not None else raise_none_error(name) | |||||
| # broadcast if the number of args > 1 | |||||
| if broadcast_shape is None: | |||||
| broadcast_shape = self.shape(arg) | |||||
| common_dtype = self.dtypeop(arg) | |||||
| else: | |||||
| ones = self.fill(self.dtypeop(arg), broadcast_shape, 1.0) | |||||
| broadcast_shape = self.shape(arg + ones) | |||||
| # check if the arguments have the same dtype | |||||
| arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0) | |||||
| dtype_tensor = self.fill(common_dtype, broadcast_shape, 1.0) | |||||
| self.sametypeshape(arg, dtype_tensor) | |||||
| arg = self.cast(arg, self.parameter_type) | |||||
| out.append(arg) | |||||
| if len(out) == 1: | |||||
| return out[0] | |||||
| # broadcast all args to broadcast_shape | |||||
| result = () | |||||
| for arg in out: | |||||
| arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0) | |||||
| result = result + (arg,) | |||||
| return result | |||||
| def _check_value(self, value, name): | def _check_value(self, value, name): | ||||
| """ | """ | ||||
| Check availability of `value` as a Tensor. | Check availability of `value` as a Tensor. | ||||
| @@ -211,163 +257,203 @@ class Distribution(Cell): | |||||
| if hasattr(self, '_cross_entropy'): | if hasattr(self, '_cross_entropy'): | ||||
| self._call_cross_entropy = self._cross_entropy | self._call_cross_entropy = self._cross_entropy | ||||
| def log_prob(self, *args, **kwargs): | |||||
| def log_prob(self, value, *args, **kwargs): | |||||
| """ | """ | ||||
| Evaluate the log probability(pdf or pmf) at the given value. | Evaluate the log probability(pdf or pmf) at the given value. | ||||
| Args: | |||||
| value (Tensor): value to be evaluated. | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| The argument `args` must include `value`. | |||||
| dist_spec_args are optional. | |||||
| A distribution can be optionally passed to the function by passing its dist_spec_args through | |||||
| `args` or `kwargs`. | |||||
| """ | """ | ||||
| return self._call_log_prob(*args, **kwargs) | |||||
| return self._call_log_prob(value, *args, **kwargs) | |||||
| def _calc_prob_from_log_prob(self, *args, **kwargs): | |||||
| def _calc_prob_from_log_prob(self, value, *args, **kwargs): | |||||
| r""" | r""" | ||||
| Evaluate prob from log probability. | Evaluate prob from log probability. | ||||
| .. math:: | .. math:: | ||||
| probability(x) = \exp(log_likehood(x)) | probability(x) = \exp(log_likehood(x)) | ||||
| """ | """ | ||||
| return self.exp(self._log_prob(*args, **kwargs)) | |||||
| return self.exp(self._log_prob(value, *args, **kwargs)) | |||||
| def prob(self, *args, **kwargs): | |||||
| def prob(self, value, *args, **kwargs): | |||||
| """ | """ | ||||
| Evaluate the probability (pdf or pmf) at given value. | Evaluate the probability (pdf or pmf) at given value. | ||||
| Args: | |||||
| value (Tensor): value to be evaluated. | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| The argument `args` must include `value`. | |||||
| dist_spec_args are optional. | |||||
| A distribution can be optionally passed to the function by passing its dist_spec_args through | |||||
| `args` or `kwargs`. | |||||
| """ | """ | ||||
| return self._call_prob(*args, **kwargs) | |||||
| return self._call_prob(value, *args, **kwargs) | |||||
| def _calc_log_prob_from_prob(self, *args, **kwargs): | |||||
| def _calc_log_prob_from_prob(self, value, *args, **kwargs): | |||||
| r""" | r""" | ||||
| Evaluate log probability from probability. | Evaluate log probability from probability. | ||||
| .. math:: | .. math:: | ||||
| log_prob(x) = \log(prob(x)) | log_prob(x) = \log(prob(x)) | ||||
| """ | """ | ||||
| return self.log(self._prob(*args, **kwargs)) | |||||
| return self.log(self._prob(value, *args, **kwargs)) | |||||
| def cdf(self, *args, **kwargs): | |||||
| def cdf(self, value, *args, **kwargs): | |||||
| """ | """ | ||||
| Evaluate the cdf at given value. | Evaluate the cdf at given value. | ||||
| Args: | |||||
| value (Tensor): value to be evaluated. | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| The argument `args` must include `value`. | |||||
| dist_spec_args are optional. | |||||
| A distribution can be optionally passed to the function by passing its dist_spec_args through | |||||
| `args` or `kwargs`. | |||||
| """ | """ | ||||
| return self._call_cdf(*args, **kwargs) | |||||
| return self._call_cdf(value, *args, **kwargs) | |||||
| def _calc_cdf_from_log_cdf(self, *args, **kwargs): | |||||
| def _calc_cdf_from_log_cdf(self, value, *args, **kwargs): | |||||
| r""" | r""" | ||||
| Evaluate cdf from log_cdf. | Evaluate cdf from log_cdf. | ||||
| .. math:: | .. math:: | ||||
| cdf(x) = \exp(log_cdf(x)) | cdf(x) = \exp(log_cdf(x)) | ||||
| """ | """ | ||||
| return self.exp(self._log_cdf(*args, **kwargs)) | |||||
| return self.exp(self._log_cdf(value, *args, **kwargs)) | |||||
| def _calc_cdf_from_survival(self, *args, **kwargs): | |||||
| def _calc_cdf_from_survival(self, value, *args, **kwargs): | |||||
| r""" | r""" | ||||
| Evaluate cdf from survival function. | Evaluate cdf from survival function. | ||||
| .. math:: | .. math:: | ||||
| cdf(x) = 1 - (survival_function(x)) | cdf(x) = 1 - (survival_function(x)) | ||||
| """ | """ | ||||
| return 1.0 - self._survival_function(*args, **kwargs) | |||||
| return 1.0 - self._survival_function(value, *args, **kwargs) | |||||
| def _calc_cdf_from_log_survival(self, *args, **kwargs): | |||||
| def _calc_cdf_from_log_survival(self, value, *args, **kwargs): | |||||
| r""" | r""" | ||||
| Evaluate cdf from log survival function. | Evaluate cdf from log survival function. | ||||
| .. math:: | .. math:: | ||||
| cdf(x) = 1 - (\exp(log_survival(x))) | cdf(x) = 1 - (\exp(log_survival(x))) | ||||
| """ | """ | ||||
| return 1.0 - self.exp(self._log_survival(*args, **kwargs)) | |||||
| return 1.0 - self.exp(self._log_survival(value, *args, **kwargs)) | |||||
| def log_cdf(self, *args, **kwargs): | |||||
| def log_cdf(self, value, *args, **kwargs): | |||||
| """ | """ | ||||
| Evaluate the log cdf at given value. | Evaluate the log cdf at given value. | ||||
| Args: | |||||
| value (Tensor): value to be evaluated. | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| The argument `args` must include `value`. | |||||
| dist_spec_args are optional. | |||||
| A distribution can be optionally passed to the function by passing its dist_spec_args through | |||||
| `args` or `kwargs`. | |||||
| """ | """ | ||||
| return self._call_log_cdf(*args, **kwargs) | |||||
| return self._call_log_cdf(value, *args, **kwargs) | |||||
| def _calc_log_cdf_from_call_cdf(self, *args, **kwargs): | |||||
| def _calc_log_cdf_from_call_cdf(self, value, *args, **kwargs): | |||||
| r""" | r""" | ||||
| Evaluate log cdf from cdf. | Evaluate log cdf from cdf. | ||||
| .. math:: | .. math:: | ||||
| log_cdf(x) = \log(cdf(x)) | log_cdf(x) = \log(cdf(x)) | ||||
| """ | """ | ||||
| return self.log(self._call_cdf(*args, **kwargs)) | |||||
| return self.log(self._call_cdf(value, *args, **kwargs)) | |||||
| def survival_function(self, *args, **kwargs): | |||||
| def survival_function(self, value, *args, **kwargs): | |||||
| """ | """ | ||||
| Evaluate the survival function at given value. | Evaluate the survival function at given value. | ||||
| Args: | |||||
| value (Tensor): value to be evaluated. | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| The argument `args` must include `value`. | |||||
| dist_spec_args are optional. | |||||
| A distribution can be optionally passed to the function by passing its dist_spec_args through | |||||
| `args` or `kwargs`. | |||||
| """ | """ | ||||
| return self._call_survival(*args, **kwargs) | |||||
| return self._call_survival(value, *args, **kwargs) | |||||
| def _calc_survival_from_call_cdf(self, *args, **kwargs): | |||||
| def _calc_survival_from_call_cdf(self, value, *args, **kwargs): | |||||
| r""" | r""" | ||||
| Evaluate survival function from cdf. | Evaluate survival function from cdf. | ||||
| .. math:: | .. math:: | ||||
| survival_function(x) = 1 - (cdf(x)) | survival_function(x) = 1 - (cdf(x)) | ||||
| """ | """ | ||||
| return 1.0 - self._call_cdf(*args, **kwargs) | |||||
| return 1.0 - self._call_cdf(value, *args, **kwargs) | |||||
| def _calc_survival_from_log_survival(self, *args, **kwargs): | |||||
| def _calc_survival_from_log_survival(self, value, *args, **kwargs): | |||||
| r""" | r""" | ||||
| Evaluate survival function from log survival function. | Evaluate survival function from log survival function. | ||||
| .. math:: | .. math:: | ||||
| survival(x) = \exp(survival_function(x)) | survival(x) = \exp(survival_function(x)) | ||||
| """ | """ | ||||
| return self.exp(self._log_survival(*args, **kwargs)) | |||||
| return self.exp(self._log_survival(value, *args, **kwargs)) | |||||
| def log_survival(self, *args, **kwargs): | |||||
| def log_survival(self, value, *args, **kwargs): | |||||
| """ | """ | ||||
| Evaluate the log survival function at given value. | Evaluate the log survival function at given value. | ||||
| Args: | |||||
| value (Tensor): value to be evaluated. | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| The arguments `args` must include `value`. | |||||
| dist_spec_args are optional. | |||||
| A distribution can be optionally passed to the function by passing its dist_spec_args through | |||||
| `args` or `kwargs`. | |||||
| """ | """ | ||||
| return self._call_log_survival(*args, **kwargs) | |||||
| return self._call_log_survival(value, *args, **kwargs) | |||||
| def _calc_log_survival_from_call_survival(self, *args, **kwargs): | |||||
| def _calc_log_survival_from_call_survival(self, value, *args, **kwargs): | |||||
| r""" | r""" | ||||
| Evaluate log survival function from survival function. | Evaluate log survival function from survival function. | ||||
| .. math:: | .. math:: | ||||
| log_survival(x) = \log(survival_function(x)) | log_survival(x) = \log(survival_function(x)) | ||||
| """ | """ | ||||
| return self.log(self._call_survival(*args, **kwargs)) | |||||
| return self.log(self._call_survival(value, *args, **kwargs)) | |||||
| def kl_loss(self, *args, **kwargs): | |||||
| def kl_loss(self, dist, *args, **kwargs): | |||||
| """ | """ | ||||
| Evaluate the KL divergence, i.e. KL(a||b). | Evaluate the KL divergence, i.e. KL(a||b). | ||||
| Args: | |||||
| dist (str): type of the distribution. | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| The argument `args` must include the type of the distribution, parameters of distribution b. | |||||
| Parameters for distribution a are optional. | |||||
| dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`. | |||||
| Passing in dist_spec_args of distribution a is optional. | |||||
| """ | """ | ||||
| return self._kl_loss(*args, **kwargs) | |||||
| return self._kl_loss(dist, *args, **kwargs) | |||||
| def mean(self, *args, **kwargs): | def mean(self, *args, **kwargs): | ||||
| """ | """ | ||||
| Evaluate the mean. | Evaluate the mean. | ||||
| Args: | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| dist_spec_args are optional. | |||||
| A distribution can be optionally passed to the function by passing its *dist_spec_args* through | |||||
| *args* or *kwargs*. | |||||
| """ | """ | ||||
| return self._mean(*args, **kwargs) | return self._mean(*args, **kwargs) | ||||
| @@ -375,8 +461,13 @@ class Distribution(Cell): | |||||
| """ | """ | ||||
| Evaluate the mode. | Evaluate the mode. | ||||
| Args: | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| dist_spec_args are optional. | |||||
| A distribution can be optionally passed to the function by passing its *dist_spec_args* through | |||||
| *args* or *kwargs*. | |||||
| """ | """ | ||||
| return self._mode(*args, **kwargs) | return self._mode(*args, **kwargs) | ||||
| @@ -384,8 +475,13 @@ class Distribution(Cell): | |||||
| """ | """ | ||||
| Evaluate the standard deviation. | Evaluate the standard deviation. | ||||
| Args: | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| dist_spec_args are optional. | |||||
| A distribution can be optionally passed to the function by passing its *dist_spec_args* through | |||||
| *args* or *kwargs*. | |||||
| """ | """ | ||||
| return self._call_sd(*args, **kwargs) | return self._call_sd(*args, **kwargs) | ||||
| @@ -393,8 +489,13 @@ class Distribution(Cell): | |||||
| """ | """ | ||||
| Evaluate the variance. | Evaluate the variance. | ||||
| Args: | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| dist_spec_args are optional. | |||||
| A distribution can be optionally passed to the function by passing its *dist_spec_args* through | |||||
| *args* or *kwargs*. | |||||
| """ | """ | ||||
| return self._call_var(*args, **kwargs) | return self._call_var(*args, **kwargs) | ||||
| @@ -420,37 +521,52 @@ class Distribution(Cell): | |||||
| """ | """ | ||||
| Evaluate the entropy. | Evaluate the entropy. | ||||
| Args: | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| dist_spec_args are optional. | |||||
| A distribution can be optionally passed to the function by passing its *dist_spec_args* through | |||||
| *args* or *kwargs*. | |||||
| """ | """ | ||||
| return self._entropy(*args, **kwargs) | return self._entropy(*args, **kwargs) | ||||
| def cross_entropy(self, *args, **kwargs): | |||||
| def cross_entropy(self, dist, *args, **kwargs): | |||||
| """ | """ | ||||
| Evaluate the cross_entropy between distribution a and b. | Evaluate the cross_entropy between distribution a and b. | ||||
| Args: | |||||
| dist (str): type of the distribution. | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| The argument `args` must include the type of the distribution, parameters of distribution b. | |||||
| Parameters for distribution a are optional. | |||||
| dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`. | |||||
| Passing in dist_spec_args of distribution a is optional. | |||||
| """ | """ | ||||
| return self._call_cross_entropy(*args, **kwargs) | |||||
| return self._call_cross_entropy(dist, *args, **kwargs) | |||||
| def _calc_cross_entropy(self, *args, **kwargs): | |||||
| def _calc_cross_entropy(self, dist, *args, **kwargs): | |||||
| r""" | r""" | ||||
| Evaluate cross_entropy from entropy and kl divergence. | Evaluate cross_entropy from entropy and kl divergence. | ||||
| .. math:: | .. math:: | ||||
| H(X, Y) = H(X) + KL(X||Y) | H(X, Y) = H(X) + KL(X||Y) | ||||
| """ | """ | ||||
| return self._entropy(*args, **kwargs) + self._kl_loss(*args, **kwargs) | |||||
| return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs) | |||||
| def sample(self, *args, **kwargs): | def sample(self, *args, **kwargs): | ||||
| """ | """ | ||||
| Sampling function. | Sampling function. | ||||
| Args: | |||||
| shape (tuple): shape of the sample. | |||||
| *args (list): the list of positional arguments forwarded to subclasses. | |||||
| **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. | |||||
| Note: | Note: | ||||
| Shape of the sample is default to (). | |||||
| dist_spec_args are optional. | |||||
| A distribution can be optionally passed to the function by passing its *dist_spec_args* through | |||||
| *args* or *kwargs*. | |||||
| """ | """ | ||||
| return self._sample(*args, **kwargs) | return self._sample(*args, **kwargs) | ||||
| @@ -18,8 +18,7 @@ from mindspore.ops import operations as P | |||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ | |||||
| raise_none_error | |||||
| from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type | |||||
| from ._utils.custom_ops import exp_generic, log_generic | from ._utils.custom_ops import exp_generic, log_generic | ||||
| class Exponential(Distribution): | class Exponential(Distribution): | ||||
| @@ -121,15 +120,19 @@ class Exponential(Distribution): | |||||
| valid_dtype = mstype.float_type | valid_dtype = mstype.float_type | ||||
| check_type(dtype, valid_dtype, type(self).__name__) | check_type(dtype, valid_dtype, type(self).__name__) | ||||
| super(Exponential, self).__init__(seed, dtype, name, param) | super(Exponential, self).__init__(seed, dtype, name, param) | ||||
| self.parameter_type = dtype | |||||
| self.parameter_type = set_param_type({'rate': rate}, self.dtype) | |||||
| if rate is not None: | if rate is not None: | ||||
| self._rate = cast_to_tensor(rate, self.parameter_type) | self._rate = cast_to_tensor(rate, self.parameter_type) | ||||
| check_greater_zero(self._rate, "rate") | check_greater_zero(self._rate, "rate") | ||||
| else: | else: | ||||
| self._rate = rate | self._rate = rate | ||||
| self.default_parameters = [self.rate] | |||||
| self.parameter_names = ['rate'] | |||||
| self.minval = np.finfo(np.float).tiny | self.minval = np.finfo(np.float).tiny | ||||
| # ops needed for the class | # ops needed for the class | ||||
| self.exp = exp_generic | self.exp = exp_generic | ||||
| self.log = log_generic | self.log = log_generic | ||||
| @@ -156,28 +159,16 @@ class Exponential(Distribution): | |||||
| @property | @property | ||||
| def rate(self): | def rate(self): | ||||
| """ | """ | ||||
| Return rate of the distribution. | |||||
| Return `rate` of the distribution. | |||||
| """ | """ | ||||
| return self._rate | return self._rate | ||||
| def _check_param(self, rate): | |||||
| """ | |||||
| Check availablity of distribution specific argument `rate`. | |||||
| """ | |||||
| if rate is not None: | |||||
| if self.context_mode == 0: | |||||
| self.checktensor(rate, 'rate') | |||||
| else: | |||||
| rate = self.checktensor(rate, 'rate') | |||||
| return self.cast(rate, self.parameter_type) | |||||
| return self.rate if self.rate is not None else raise_none_error('rate') | |||||
| def _mean(self, rate=None): | def _mean(self, rate=None): | ||||
| r""" | r""" | ||||
| .. math:: | .. math:: | ||||
| MEAN(EXP) = \frac{1.0}{\lambda}. | MEAN(EXP) = \frac{1.0}{\lambda}. | ||||
| """ | """ | ||||
| rate = self._check_param(rate) | |||||
| rate = self._check_param_type(rate) | |||||
| return 1.0 / rate | return 1.0 / rate | ||||
| def _mode(self, rate=None): | def _mode(self, rate=None): | ||||
| @@ -185,7 +176,7 @@ class Exponential(Distribution): | |||||
| .. math:: | .. math:: | ||||
| MODE(EXP) = 0. | MODE(EXP) = 0. | ||||
| """ | """ | ||||
| rate = self._check_param(rate) | |||||
| rate = self._check_param_type(rate) | |||||
| return self.fill(self.dtype, self.shape(rate), 0.) | return self.fill(self.dtype, self.shape(rate), 0.) | ||||
| def _sd(self, rate=None): | def _sd(self, rate=None): | ||||
| @@ -193,7 +184,7 @@ class Exponential(Distribution): | |||||
| .. math:: | .. math:: | ||||
| sd(EXP) = \frac{1.0}{\lambda}. | sd(EXP) = \frac{1.0}{\lambda}. | ||||
| """ | """ | ||||
| rate = self._check_param(rate) | |||||
| rate = self._check_param_type(rate) | |||||
| return 1.0 / rate | return 1.0 / rate | ||||
| def _entropy(self, rate=None): | def _entropy(self, rate=None): | ||||
| @@ -201,7 +192,7 @@ class Exponential(Distribution): | |||||
| .. math:: | .. math:: | ||||
| H(Exp) = 1 - \log(\lambda). | H(Exp) = 1 - \log(\lambda). | ||||
| """ | """ | ||||
| rate = self._check_param(rate) | |||||
| rate = self._check_param_type(rate) | |||||
| return 1.0 - self.log(rate) | return 1.0 - self.log(rate) | ||||
| def _cross_entropy(self, dist, rate_b, rate=None): | def _cross_entropy(self, dist, rate_b, rate=None): | ||||
| @@ -234,7 +225,7 @@ class Exponential(Distribution): | |||||
| """ | """ | ||||
| value = self._check_value(value, "value") | value = self._check_value(value, "value") | ||||
| value = self.cast(value, self.dtype) | value = self.cast(value, self.dtype) | ||||
| rate = self._check_param(rate) | |||||
| rate = self._check_param_type(rate) | |||||
| prob = self.log(rate) - rate * value | prob = self.log(rate) - rate * value | ||||
| zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) | zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) | ||||
| neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf) | neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf) | ||||
| @@ -257,7 +248,7 @@ class Exponential(Distribution): | |||||
| """ | """ | ||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, self.dtype) | value = self.cast(value, self.dtype) | ||||
| rate = self._check_param(rate) | |||||
| rate = self._check_param_type(rate) | |||||
| cdf = 1.0 - self.exp(-1. * rate * value) | cdf = 1.0 - self.exp(-1. * rate * value) | ||||
| zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) | zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) | ||||
| comp = self.less(value, zeros) | comp = self.less(value, zeros) | ||||
| @@ -279,7 +270,7 @@ class Exponential(Distribution): | |||||
| """ | """ | ||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, self.dtype) | value = self.cast(value, self.dtype) | ||||
| rate = self._check_param(rate) | |||||
| rate = self._check_param_type(rate) | |||||
| sf = -1. * rate * value | sf = -1. * rate * value | ||||
| zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0) | zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0) | ||||
| comp = self.less(value, zeros) | comp = self.less(value, zeros) | ||||
| @@ -297,7 +288,7 @@ class Exponential(Distribution): | |||||
| check_distribution_name(dist, 'Exponential') | check_distribution_name(dist, 'Exponential') | ||||
| rate_b = self._check_value(rate_b, 'rate_b') | rate_b = self._check_value(rate_b, 'rate_b') | ||||
| rate_b = self.cast(rate_b, self.parameter_type) | rate_b = self.cast(rate_b, self.parameter_type) | ||||
| rate_a = self._check_param(rate) | |||||
| rate_a = self._check_param_type(rate) | |||||
| return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 | return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 | ||||
| def _sample(self, shape=(), rate=None): | def _sample(self, shape=(), rate=None): | ||||
| @@ -312,7 +303,7 @@ class Exponential(Distribution): | |||||
| Tensor, shape is shape + batch_shape. | Tensor, shape is shape + batch_shape. | ||||
| """ | """ | ||||
| shape = self.checktuple(shape, 'shape') | shape = self.checktuple(shape, 'shape') | ||||
| rate = self._check_param(rate) | |||||
| rate = self._check_param_type(rate) | |||||
| origin_shape = shape + self.shape(rate) | origin_shape = shape + self.shape(rate) | ||||
| if origin_shape == (): | if origin_shape == (): | ||||
| sample_shape = (1,) | sample_shape = (1,) | ||||
| @@ -19,7 +19,7 @@ from mindspore.ops import composite as C | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ | from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ | ||||
| raise_none_error | |||||
| set_param_type | |||||
| from ._utils.custom_ops import exp_generic, log_generic | from ._utils.custom_ops import exp_generic, log_generic | ||||
| @@ -123,13 +123,16 @@ class Geometric(Distribution): | |||||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | ||||
| check_type(dtype, valid_dtype, type(self).__name__) | check_type(dtype, valid_dtype, type(self).__name__) | ||||
| super(Geometric, self).__init__(seed, dtype, name, param) | super(Geometric, self).__init__(seed, dtype, name, param) | ||||
| self.parameter_type = mstype.float32 | |||||
| self.parameter_type = set_param_type({'probs1': probs}, mstype.float32) | |||||
| if probs is not None: | if probs is not None: | ||||
| self._probs = cast_to_tensor(probs, self.parameter_type) | self._probs = cast_to_tensor(probs, self.parameter_type) | ||||
| check_prob(self._probs) | check_prob(self._probs) | ||||
| else: | else: | ||||
| self._probs = probs | self._probs = probs | ||||
| self.default_parameters = [self.probs] | |||||
| self.parameter_names = ['probs1'] | |||||
| self.minval = np.finfo(np.float).tiny | self.minval = np.finfo(np.float).tiny | ||||
| # ops needed for the class | # ops needed for the class | ||||
| @@ -164,24 +167,12 @@ class Geometric(Distribution): | |||||
| """ | """ | ||||
| return self._probs | return self._probs | ||||
| def _check_param(self, probs1): | |||||
| """ | |||||
| Check availablity of distribution specific args probs1. | |||||
| """ | |||||
| if probs1 is not None: | |||||
| if self.context_mode == 0: | |||||
| self.checktensor(probs1, 'probs1') | |||||
| else: | |||||
| probs1 = self.checktensor(probs1, 'probs1') | |||||
| return self.cast(probs1, self.parameter_type) | |||||
| return self.probs if self.probs is not None else raise_none_error('probs1') | |||||
| def _mean(self, probs1=None): | def _mean(self, probs1=None): | ||||
| r""" | r""" | ||||
| .. math:: | .. math:: | ||||
| MEAN(Geo) = \fratc{1 - probs1}{probs1} | MEAN(Geo) = \fratc{1 - probs1}{probs1} | ||||
| """ | """ | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| return (1. - probs1) / probs1 | return (1. - probs1) / probs1 | ||||
| def _mode(self, probs1=None): | def _mode(self, probs1=None): | ||||
| @@ -189,7 +180,7 @@ class Geometric(Distribution): | |||||
| .. math:: | .. math:: | ||||
| MODE(Geo) = 0 | MODE(Geo) = 0 | ||||
| """ | """ | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.) | return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.) | ||||
| def _var(self, probs1=None): | def _var(self, probs1=None): | ||||
| @@ -197,7 +188,7 @@ class Geometric(Distribution): | |||||
| .. math:: | .. math:: | ||||
| VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}} | VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}} | ||||
| """ | """ | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| return (1.0 - probs1) / self.sq(probs1) | return (1.0 - probs1) / self.sq(probs1) | ||||
| def _entropy(self, probs1=None): | def _entropy(self, probs1=None): | ||||
| @@ -205,7 +196,7 @@ class Geometric(Distribution): | |||||
| .. math:: | .. math:: | ||||
| H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1} | H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1} | ||||
| """ | """ | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| probs0 = 1.0 - probs1 | probs0 = 1.0 - probs1 | ||||
| return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1 | return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1 | ||||
| @@ -236,7 +227,7 @@ class Geometric(Distribution): | |||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, mstype.float32) | value = self.cast(value, mstype.float32) | ||||
| value = self.floor(value) | value = self.floor(value) | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) | pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) | ||||
| zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) | zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) | ||||
| comp = self.less(value, zeros) | comp = self.less(value, zeros) | ||||
| @@ -258,7 +249,7 @@ class Geometric(Distribution): | |||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, mstype.float32) | value = self.cast(value, mstype.float32) | ||||
| value = self.floor(value) | value = self.floor(value) | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| probs0 = 1.0 - probs1 | probs0 = 1.0 - probs1 | ||||
| cdf = 1.0 - self.pow(probs0, value + 1.0) | cdf = 1.0 - self.pow(probs0, value + 1.0) | ||||
| zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0) | zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0) | ||||
| @@ -280,7 +271,7 @@ class Geometric(Distribution): | |||||
| check_distribution_name(dist, 'Geometric') | check_distribution_name(dist, 'Geometric') | ||||
| probs1_b = self._check_value(probs1_b, 'probs1_b') | probs1_b = self._check_value(probs1_b, 'probs1_b') | ||||
| probs1_b = self.cast(probs1_b, self.parameter_type) | probs1_b = self.cast(probs1_b, self.parameter_type) | ||||
| probs1_a = self._check_param(probs1) | |||||
| probs1_a = self._check_param_type(probs1) | |||||
| probs0_a = 1.0 - probs1_a | probs0_a = 1.0 - probs1_a | ||||
| probs0_b = 1.0 - probs1_b | probs0_b = 1.0 - probs1_b | ||||
| return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b) | return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b) | ||||
| @@ -297,7 +288,7 @@ class Geometric(Distribution): | |||||
| Tensor, shape is shape + batch_shape. | Tensor, shape is shape + batch_shape. | ||||
| """ | """ | ||||
| shape = self.checktuple(shape, 'shape') | shape = self.checktuple(shape, 'shape') | ||||
| probs1 = self._check_param(probs1) | |||||
| probs1 = self._check_param_type(probs1) | |||||
| origin_shape = shape + self.shape(probs1) | origin_shape = shape + self.shape(probs1) | ||||
| if origin_shape == (): | if origin_shape == (): | ||||
| sample_shape = (1,) | sample_shape = (1,) | ||||
| @@ -19,7 +19,7 @@ from mindspore.ops import composite as C | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ | from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ | ||||
| raise_none_error, common_dtype | |||||
| set_param_type | |||||
| from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic | from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic | ||||
| class Normal(Distribution): | class Normal(Distribution): | ||||
| @@ -127,14 +127,17 @@ class Normal(Distribution): | |||||
| valid_dtype = mstype.float_type | valid_dtype = mstype.float_type | ||||
| check_type(dtype, valid_dtype, type(self).__name__) | check_type(dtype, valid_dtype, type(self).__name__) | ||||
| super(Normal, self).__init__(seed, dtype, name, param) | super(Normal, self).__init__(seed, dtype, name, param) | ||||
| self.parameter_type = common_dtype(mean, 'mean', sd, 'sd', self.dtype) | |||||
| self.parameter_type = set_param_type({'mean': mean, 'sd': sd}, self.dtype) | |||||
| if mean is not None and sd is not None: | if mean is not None and sd is not None: | ||||
| self._mean_value = cast_to_tensor(mean, self.parameter_type) | self._mean_value = cast_to_tensor(mean, self.parameter_type) | ||||
| self._sd_value = cast_to_tensor(sd, self.parameter_type) | self._sd_value = cast_to_tensor(sd, self.parameter_type) | ||||
| check_greater_zero(self._sd_value, "Standard deviation") | check_greater_zero(self._sd_value, "Standard deviation") | ||||
| else: | else: | ||||
| self._mean_value = mean | |||||
| self._sd_value = sd | |||||
| self._mean_value = mean if mean is None else cast_to_tensor(mean, self.parameter_type) | |||||
| self._sd_value = sd if sd is None else cast_to_tensor(sd, self.parameter_type) | |||||
| self.default_parameters = [self._mean_value, self._sd_value] | |||||
| self.parameter_names = ['mean', 'sd'] | |||||
| #ops needed for the class | #ops needed for the class | ||||
| self.exp = exp_generic | self.exp = exp_generic | ||||
| @@ -159,51 +162,25 @@ class Normal(Distribution): | |||||
| str_info = f'batch_shape = {self._broadcast_shape}' | str_info = f'batch_shape = {self._broadcast_shape}' | ||||
| return str_info | return str_info | ||||
| def _check_param(self, mean, sd): | |||||
| """ | |||||
| Check availablity of distribution specific args `mean` and `sd`. | |||||
| """ | |||||
| if mean is not None: | |||||
| if self.context_mode == 0: | |||||
| self.checktensor(mean, 'mean') | |||||
| else: | |||||
| mean = self.checktensor(mean, 'mean') | |||||
| else: | |||||
| mean = self._mean_value if self._mean_value is not None else raise_none_error('mean') | |||||
| if sd is not None: | |||||
| if self.context_mode == 0: | |||||
| self.checktensor(sd, 'sd') | |||||
| else: | |||||
| sd = self.checktensor(sd, 'sd') | |||||
| else: | |||||
| sd = self._sd_value if self._sd_value is not None else raise_none_error('sd') | |||||
| batch_shape = self.shape(mean + sd) | |||||
| mean = mean * self.fill(self.dtypeop(mean), batch_shape, 1.0) | |||||
| sd = sd * self.fill(self.dtypeop(sd), batch_shape, 1.0) | |||||
| self.sametypeshape(mean, sd) | |||||
| mean = self.cast(mean, self.parameter_type) | |||||
| sd = self.cast(sd, self.parameter_type) | |||||
| return mean, sd | |||||
| def _mean(self, mean=None, sd=None): | def _mean(self, mean=None, sd=None): | ||||
| """ | """ | ||||
| The mean of the distribution. | The mean of the distribution. | ||||
| """ | """ | ||||
| mean, sd = self._check_param(mean, sd) | |||||
| mean, sd = self._check_param_type(mean, sd) | |||||
| return mean | return mean | ||||
| def _mode(self, mean=None, sd=None): | def _mode(self, mean=None, sd=None): | ||||
| """ | """ | ||||
| The mode of the distribution. | The mode of the distribution. | ||||
| """ | """ | ||||
| mean, sd = self._check_param(mean, sd) | |||||
| mean, sd = self._check_param_type(mean, sd) | |||||
| return mean | return mean | ||||
| def _sd(self, mean=None, sd=None): | def _sd(self, mean=None, sd=None): | ||||
| """ | """ | ||||
| The standard deviation of the distribution. | The standard deviation of the distribution. | ||||
| """ | """ | ||||
| mean, sd = self._check_param(mean, sd) | |||||
| mean, sd = self._check_param_type(mean, sd) | |||||
| return sd | return sd | ||||
| def _entropy(self, mean=None, sd=None): | def _entropy(self, mean=None, sd=None): | ||||
| @@ -213,7 +190,7 @@ class Normal(Distribution): | |||||
| .. math:: | .. math:: | ||||
| H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) | H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) | ||||
| """ | """ | ||||
| mean, sd = self._check_param(mean, sd) | |||||
| mean, sd = self._check_param_type(mean, sd) | |||||
| return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd) | return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd) | ||||
| def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None): | def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None): | ||||
| @@ -244,7 +221,7 @@ class Normal(Distribution): | |||||
| """ | """ | ||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, self.dtype) | value = self.cast(value, self.dtype) | ||||
| mean, sd = self._check_param(mean, sd) | |||||
| mean, sd = self._check_param_type(mean, sd) | |||||
| unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) | unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) | ||||
| neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd) | neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd) | ||||
| return unnormalized_log_prob + neg_normalization | return unnormalized_log_prob + neg_normalization | ||||
| @@ -263,7 +240,7 @@ class Normal(Distribution): | |||||
| """ | """ | ||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, self.dtype) | value = self.cast(value, self.dtype) | ||||
| mean, sd = self._check_param(mean, sd) | |||||
| mean, sd = self._check_param_type(mean, sd) | |||||
| sqrt2 = self.sqrt(self.const(2.0)) | sqrt2 = self.sqrt(self.const(2.0)) | ||||
| adjusted = (value - mean) / (sd * sqrt2) | adjusted = (value - mean) / (sd * sqrt2) | ||||
| return 0.5 * (1.0 + self.erf(adjusted)) | return 0.5 * (1.0 + self.erf(adjusted)) | ||||
| @@ -288,7 +265,7 @@ class Normal(Distribution): | |||||
| sd_b = self._check_value(sd_b, 'sd_b') | sd_b = self._check_value(sd_b, 'sd_b') | ||||
| mean_b = self.cast(mean_b, self.parameter_type) | mean_b = self.cast(mean_b, self.parameter_type) | ||||
| sd_b = self.cast(sd_b, self.parameter_type) | sd_b = self.cast(sd_b, self.parameter_type) | ||||
| mean_a, sd_a = self._check_param(mean, sd) | |||||
| mean_a, sd_a = self._check_param_type(mean, sd) | |||||
| diff_log_scale = self.log(sd_a) - self.log(sd_b) | diff_log_scale = self.log(sd_a) - self.log(sd_b) | ||||
| squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b) | squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b) | ||||
| return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale | return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale | ||||
| @@ -306,7 +283,7 @@ class Normal(Distribution): | |||||
| Tensor, shape is shape + batch_shape. | Tensor, shape is shape + batch_shape. | ||||
| """ | """ | ||||
| shape = self.checktuple(shape, 'shape') | shape = self.checktuple(shape, 'shape') | ||||
| mean, sd = self._check_param(mean, sd) | |||||
| mean, sd = self._check_param_type(mean, sd) | |||||
| batch_shape = self.shape(mean + sd) | batch_shape = self.shape(mean + sd) | ||||
| origin_shape = shape + batch_shape | origin_shape = shape + batch_shape | ||||
| if origin_shape == (): | if origin_shape == (): | ||||
| @@ -18,7 +18,7 @@ from mindspore.ops import composite as C | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\ | from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\ | ||||
| raise_none_error, common_dtype | |||||
| set_param_type | |||||
| from ._utils.custom_ops import exp_generic, log_generic | from ._utils.custom_ops import exp_generic, log_generic | ||||
| class Uniform(Distribution): | class Uniform(Distribution): | ||||
| @@ -126,14 +126,17 @@ class Uniform(Distribution): | |||||
| valid_dtype = mstype.float_type | valid_dtype = mstype.float_type | ||||
| check_type(dtype, valid_dtype, type(self).__name__) | check_type(dtype, valid_dtype, type(self).__name__) | ||||
| super(Uniform, self).__init__(seed, dtype, name, param) | super(Uniform, self).__init__(seed, dtype, name, param) | ||||
| self.parameter_type = common_dtype(low, 'low', high, 'high', self.dtype) | |||||
| self.parameter_type = set_param_type({'low': low, 'high': high}, self.dtype) | |||||
| if low is not None and high is not None: | if low is not None and high is not None: | ||||
| self._low = cast_to_tensor(low, dtype) | |||||
| self._high = cast_to_tensor(high, dtype) | |||||
| self._low = cast_to_tensor(low, self.parameter_type) | |||||
| self._high = cast_to_tensor(high, self.parameter_type) | |||||
| check_greater(self.low, self.high, "low value", "high value") | check_greater(self.low, self.high, "low value", "high value") | ||||
| else: | else: | ||||
| self._low = low | |||||
| self._high = high | |||||
| self._low = low if low is None else cast_to_tensor(low, self.parameter_type) | |||||
| self._high = high if high is None else cast_to_tensor(high, self.parameter_type) | |||||
| self.default_parameters = [self.low, self.high] | |||||
| self.parameter_names = ['low', 'high'] | |||||
| # ops needed for the class | # ops needed for the class | ||||
| self.exp = exp_generic | self.exp = exp_generic | ||||
| @@ -162,32 +165,6 @@ class Uniform(Distribution): | |||||
| str_info = f'batch_shape = {self._broadcast_shape}' | str_info = f'batch_shape = {self._broadcast_shape}' | ||||
| return str_info | return str_info | ||||
| def _check_param(self, low, high): | |||||
| """ | |||||
| Check availablity of distribution specific args `low` and `high`. | |||||
| """ | |||||
| if low is not None: | |||||
| if self.context_mode == 0: | |||||
| self.checktensor(low, 'low') | |||||
| else: | |||||
| low = self.checktensor(low, 'low') | |||||
| else: | |||||
| low = self.low if self.low is not None else raise_none_error('low') | |||||
| if high is not None: | |||||
| if self.context_mode == 0: | |||||
| self.checktensor(high, 'high') | |||||
| else: | |||||
| high = self.checktensor(high, 'high') | |||||
| else: | |||||
| high = self.high if self.high is not None else raise_none_error('high') | |||||
| batch_shape = self.shape(high - low) | |||||
| high = high * self.fill(self.dtypeop(high), batch_shape, 1.0) | |||||
| low = low * self.fill(self.dtypeop(low), batch_shape, 1.0) | |||||
| self.sametypeshape(high, low) | |||||
| low = self.cast(low, self.parameter_type) | |||||
| high = self.cast(high, self.parameter_type) | |||||
| return low, high | |||||
| @property | @property | ||||
| def low(self): | def low(self): | ||||
| """ | """ | ||||
| @@ -209,7 +186,7 @@ class Uniform(Distribution): | |||||
| .. math:: | .. math:: | ||||
| range(U) = high -low | range(U) = high -low | ||||
| """ | """ | ||||
| low, high = self._check_param(low, high) | |||||
| low, high = self._check_param_type(low, high) | |||||
| return high - low | return high - low | ||||
| def _mean(self, low=None, high=None): | def _mean(self, low=None, high=None): | ||||
| @@ -217,7 +194,7 @@ class Uniform(Distribution): | |||||
| .. math:: | .. math:: | ||||
| MEAN(U) = \frac{low + high}{2}. | MEAN(U) = \frac{low + high}{2}. | ||||
| """ | """ | ||||
| low, high = self._check_param(low, high) | |||||
| low, high = self._check_param_type(low, high) | |||||
| return (low + high) / 2. | return (low + high) / 2. | ||||
| def _var(self, low=None, high=None): | def _var(self, low=None, high=None): | ||||
| @@ -225,7 +202,7 @@ class Uniform(Distribution): | |||||
| .. math:: | .. math:: | ||||
| VAR(U) = \frac{(high -low) ^ 2}{12}. | VAR(U) = \frac{(high -low) ^ 2}{12}. | ||||
| """ | """ | ||||
| low, high = self._check_param(low, high) | |||||
| low, high = self._check_param_type(low, high) | |||||
| return self.sq(high - low) / 12.0 | return self.sq(high - low) / 12.0 | ||||
| def _entropy(self, low=None, high=None): | def _entropy(self, low=None, high=None): | ||||
| @@ -233,7 +210,7 @@ class Uniform(Distribution): | |||||
| .. math:: | .. math:: | ||||
| H(U) = \log(high - low). | H(U) = \log(high - low). | ||||
| """ | """ | ||||
| low, high = self._check_param(low, high) | |||||
| low, high = self._check_param_type(low, high) | |||||
| return self.log(high - low) | return self.log(high - low) | ||||
| def _cross_entropy(self, dist, low_b, high_b, low=None, high=None): | def _cross_entropy(self, dist, low_b, high_b, low=None, high=None): | ||||
| @@ -266,7 +243,7 @@ class Uniform(Distribution): | |||||
| """ | """ | ||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, self.dtype) | value = self.cast(value, self.dtype) | ||||
| low, high = self._check_param(low, high) | |||||
| low, high = self._check_param_type(low, high) | |||||
| neg_ones = self.fill(self.dtype, self.shape(value), -1.0) | neg_ones = self.fill(self.dtype, self.shape(value), -1.0) | ||||
| prob = self.exp(neg_ones * self.log(high - low)) | prob = self.exp(neg_ones * self.log(high - low)) | ||||
| broadcast_shape = self.shape(prob) | broadcast_shape = self.shape(prob) | ||||
| @@ -292,7 +269,7 @@ class Uniform(Distribution): | |||||
| low_b = self.cast(low_b, self.parameter_type) | low_b = self.cast(low_b, self.parameter_type) | ||||
| high_b = self._check_value(high_b, 'high_b') | high_b = self._check_value(high_b, 'high_b') | ||||
| high_b = self.cast(high_b, self.parameter_type) | high_b = self.cast(high_b, self.parameter_type) | ||||
| low_a, high_a = self._check_param(low, high) | |||||
| low_a, high_a = self._check_param_type(low, high) | |||||
| kl = self.log(high_b - low_b) - self.log(high_a - low_a) | kl = self.log(high_b - low_b) - self.log(high_a - low_a) | ||||
| comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b)) | comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b)) | ||||
| return self.select(comp, kl, self.log(self.zeroslike(kl))) | return self.select(comp, kl, self.log(self.zeroslike(kl))) | ||||
| @@ -313,7 +290,7 @@ class Uniform(Distribution): | |||||
| """ | """ | ||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, self.dtype) | value = self.cast(value, self.dtype) | ||||
| low, high = self._check_param(low, high) | |||||
| low, high = self._check_param_type(low, high) | |||||
| prob = (value - low) / (high - low) | prob = (value - low) / (high - low) | ||||
| broadcast_shape = self.shape(prob) | broadcast_shape = self.shape(prob) | ||||
| zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) | zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) | ||||
| @@ -336,7 +313,7 @@ class Uniform(Distribution): | |||||
| Tensor, shape is shape + batch_shape. | Tensor, shape is shape + batch_shape. | ||||
| """ | """ | ||||
| shape = self.checktuple(shape, 'shape') | shape = self.checktuple(shape, 'shape') | ||||
| low, high = self._check_param(low, high) | |||||
| low, high = self._check_param_type(low, high) | |||||
| broadcast_shape = self.shape(low + high) | broadcast_shape = self.shape(low + high) | ||||
| origin_shape = shape + broadcast_shape | origin_shape = shape + broadcast_shape | ||||
| if origin_shape == (): | if origin_shape == (): | ||||
| @@ -0,0 +1,182 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Test util functions used in distribution classes. | |||||
| """ | |||||
| import numpy as np | |||||
| import pytest | |||||
| from mindspore.nn.cell import Cell | |||||
| from mindspore import context | |||||
| from mindspore import dtype | |||||
| from mindspore import Tensor | |||||
| from mindspore.common.parameter import Parameter | |||||
| from mindspore.nn.probability.distribution._utils.utils import set_param_type, \ | |||||
| cast_to_tensor, CheckTuple, CheckTensor | |||||
| def test_set_param_type(): | |||||
| """ | |||||
| Test set_param_type function. | |||||
| """ | |||||
| tensor_fp16 = Tensor(0.1, dtype=dtype.float16) | |||||
| tensor_fp32 = Tensor(0.1, dtype=dtype.float32) | |||||
| tensor_fp64 = Tensor(0.1, dtype=dtype.float64) | |||||
| tensor_int32 = Tensor(0.1, dtype=dtype.int32) | |||||
| array_fp32 = np.array(1.0).astype(np.float32) | |||||
| array_fp64 = np.array(1.0).astype(np.float64) | |||||
| array_int32 = np.array(1.0).astype(np.int32) | |||||
| dict1 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp32} | |||||
| dict2 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp64} | |||||
| dict3 = {'a': tensor_int32, 'b': 1.0, 'c': tensor_int32} | |||||
| dict4 = {'a': array_fp32, 'b': 1.0, 'c': tensor_fp32} | |||||
| dict5 = {'a': array_fp32, 'b': 1.0, 'c': array_fp64} | |||||
| dict6 = {'a': array_fp32, 'b': 1.0, 'c': array_int32} | |||||
| dict7 = {'a': 1.0} | |||||
| dict8 = {'a': 1.0, 'b': 1.0, 'c': 1.0} | |||||
| dict9 = {'a': tensor_fp16, 'b': tensor_fp16, 'c': tensor_fp16} | |||||
| dict10 = {'a': tensor_fp64, 'b': tensor_fp64, 'c': tensor_fp64} | |||||
| dict11 = {'a': array_fp64, 'b': array_fp64, 'c': tensor_fp64} | |||||
| ans1 = set_param_type(dict1, dtype.float16) | |||||
| assert ans1 == dtype.float32 | |||||
| with pytest.raises(TypeError): | |||||
| set_param_type(dict2, dtype.float32) | |||||
| ans3 = set_param_type(dict3, dtype.float16) | |||||
| assert ans3 == dtype.float32 | |||||
| ans4 = set_param_type(dict4, dtype.float16) | |||||
| assert ans4 == dtype.float32 | |||||
| with pytest.raises(TypeError): | |||||
| set_param_type(dict5, dtype.float32) | |||||
| with pytest.raises(TypeError): | |||||
| set_param_type(dict6, dtype.float32) | |||||
| ans7 = set_param_type(dict7, dtype.float32) | |||||
| assert ans7 == dtype.float32 | |||||
| ans8 = set_param_type(dict8, dtype.float32) | |||||
| assert ans8 == dtype.float32 | |||||
| ans9 = set_param_type(dict9, dtype.float32) | |||||
| assert ans9 == dtype.float16 | |||||
| ans10 = set_param_type(dict10, dtype.float32) | |||||
| assert ans10 == dtype.float32 | |||||
| ans11 = set_param_type(dict11, dtype.float32) | |||||
| assert ans11 == dtype.float32 | |||||
| def test_cast_to_tensor(): | |||||
| """ | |||||
| Test cast_to_tensor. | |||||
| """ | |||||
| with pytest.raises(ValueError): | |||||
| cast_to_tensor(None, dtype.float32) | |||||
| with pytest.raises(TypeError): | |||||
| cast_to_tensor(True, dtype.float32) | |||||
| with pytest.raises(TypeError): | |||||
| cast_to_tensor({'a': 1, 'b': 2}, dtype.float32) | |||||
| with pytest.raises(TypeError): | |||||
| cast_to_tensor('tensor', dtype.float32) | |||||
| ans1 = cast_to_tensor(Parameter(Tensor(0.1, dtype=dtype.float32), 'param')) | |||||
| assert isinstance(ans1, Parameter) | |||||
| ans2 = cast_to_tensor(np.array(1.0).astype(np.float32)) | |||||
| assert isinstance(ans2, Tensor) | |||||
| ans3 = cast_to_tensor([1.0, 2.0]) | |||||
| assert isinstance(ans3, Tensor) | |||||
| ans4 = cast_to_tensor(Tensor(0.1, dtype=dtype.float32), dtype.float32) | |||||
| assert isinstance(ans4, Tensor) | |||||
| ans5 = cast_to_tensor(0.1, dtype.float32) | |||||
| assert isinstance(ans5, Tensor) | |||||
| ans6 = cast_to_tensor(1, dtype.float32) | |||||
| assert isinstance(ans6, Tensor) | |||||
| class Net(Cell): | |||||
| """ | |||||
| Test class: CheckTuple. | |||||
| """ | |||||
| def __init__(self, value): | |||||
| super(Net, self).__init__() | |||||
| self.checktuple = CheckTuple() | |||||
| self.value = value | |||||
| def construct(self, value=None): | |||||
| if value is None: | |||||
| return self.checktuple(self.value, 'input') | |||||
| return self.checktuple(value, 'input') | |||||
| def test_check_tuple(): | |||||
| """ | |||||
| Test CheckTuple. | |||||
| """ | |||||
| net1 = Net((1, 2, 3)) | |||||
| ans1 = net1() | |||||
| assert isinstance(ans1, tuple) | |||||
| with pytest.raises(TypeError): | |||||
| net2 = Net('tuple') | |||||
| net2() | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| net3 = Net((1, 2, 3)) | |||||
| ans3 = net3() | |||||
| assert isinstance(ans3, tuple) | |||||
| with pytest.raises(TypeError): | |||||
| net4 = Net('tuple') | |||||
| net4() | |||||
| class Net1(Cell): | |||||
| """ | |||||
| Test class: CheckTensor. | |||||
| """ | |||||
| def __init__(self, value): | |||||
| super(Net1, self).__init__() | |||||
| self.checktensor = CheckTensor() | |||||
| self.value = value | |||||
| self.context = context.get_context('mode') | |||||
| def construct(self, value=None): | |||||
| value = self.value if value is None else value | |||||
| if self.context == 0: | |||||
| self.checktensor(value, 'input') | |||||
| return value | |||||
| return self.checktensor(value, 'input') | |||||
| def test_check_tensor(): | |||||
| """ | |||||
| Test CheckTensor. | |||||
| """ | |||||
| value = Tensor(0.1, dtype=dtype.float32) | |||||
| net1 = Net1(value) | |||||
| ans1 = net1() | |||||
| assert isinstance(ans1, Tensor) | |||||
| ans1 = net1(value) | |||||
| assert isinstance(ans1, Tensor) | |||||
| with pytest.raises(TypeError): | |||||
| net2 = Net1('tuple') | |||||
| net2() | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| net3 = Net1(value) | |||||
| ans3 = net3() | |||||
| assert isinstance(ans3, Tensor) | |||||
| ans3 = net3(value) | |||||
| assert isinstance(ans3, Tensor) | |||||
| with pytest.raises(TypeError): | |||||
| net4 = Net1('tuple') | |||||
| net4() | |||||