@@ -19,17 +19,18 @@ from .utils import *
 from .custom_ops import *

 __all__ = [
-    'convert_to_batch',
     'cast_to_tensor',
     'check_greater',
     'check_greater_equal_zero',
     'check_greater_zero',
-    'calc_broadcast_shape_from_param',
-    'check_scalar_from_param',
     'check_prob',
     'check_type',
     'exp_generic',
     'expm1_generic',
     'log_generic',
     'log1p_generic',
+    'broadcast_to',
+    'set_param_type',
+    'CheckTensor',
+    'CheckTuple',
 ]
@@ -72,3 +72,12 @@ def log1p_generic(x):
     Log1p ops on GPU device or when device_target == GPU.
     """
     return log_generic(x + 1.0)
+
+def broadcast_to(x, target):
+    """
+    Broadcast x to the shape of target.
+    """
+    shape = P.Shape()
+    if shape(x) == shape(target):
+        return x
+    return P.BroadcastTo(shape(target))(x)
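Reviewer note: the early return when the shapes already match keeps a redundant `BroadcastTo` node out of the graph. A minimal NumPy stand-in for the intended behavior (illustrative only, not the MindSpore API):

```python
import numpy as np

# broadcast_to(x, target) returns x unchanged when the shapes already
# match, and otherwise expands x to target's shape.
def broadcast_to_sketch(x, target):
    if x.shape == target.shape:
        return x                      # no-op: avoids an extra broadcast op
    return np.broadcast_to(x, target.shape)

x = np.array([1.0, 2.0])              # shape (2,)
target = np.ones((3, 2))              # shape (3, 2)
print(broadcast_to_sketch(x, target).shape)  # (3, 2)
```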
@@ -19,13 +19,10 @@ from mindspore._checkparam import Validator as validator
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
-from mindspore.ops import _utils as utils
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
 import mindspore.nn as nn
-import mindspore.nn.probability as msp


 def cast_to_tensor(t, hint_type=mstype.float32):
     """
@@ -46,41 +43,13 @@ def cast_to_tensor(t, hint_type=mstype.float32):
         raise ValueError(f'Input cannot be None in cast_to_tensor')
     if isinstance(t, Parameter):
         return t
-    t_type = hint_type
-    if isinstance(t, Tensor):
-        # convert the type of tensor to dtype
-        return Tensor(t.asnumpy(), dtype=t_type)
-    if isinstance(t, (list, np.ndarray)):
-        return Tensor(t, dtype=t_type)
     if isinstance(t, bool):
         raise TypeError(f'Input cannot be Type Bool')
-    if isinstance(t, (int, float)):
-        return Tensor(t, dtype=t_type)
+    if isinstance(t, (Tensor, np.ndarray, list, int, float)):
+        return Tensor(t, dtype=hint_type)
     invalid_type = type(t)
     raise TypeError(
-        f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")
-
-
-def convert_to_batch(t, batch_shape, required_type):
-    """
-    Convert a Tensor to a given batch shape.
-
-    Args:
-        t (int, float, list, numpy.ndarray, Tensor, Parameter): Tensor to be converted.
-        batch_shape (tuple): desired batch shape.
-        dtype (mindspore.dtype): desired dtype.
-
-    Raises:
-        RuntimeError: if the converison cannot be done.
-
-    Returns:
-        Tensor, with shape of batch_shape.
-    """
-    if isinstance(t, Parameter):
-        return t
-    t = cast_to_tensor(t, required_type)
-    return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)
+        f"Unable to convert input of type {invalid_type} to a Tensor of type {hint_type}")


 def cast_type_for_device(dtype):
     """
@@ -100,54 +69,6 @@ def cast_type_for_device(dtype):
     return dtype


-def check_scalar_from_param(params):
-    """
-    Check if params are all scalars.
-
-    Args:
-        params (dict): parameters used to initialize distribution.
-
-    Notes: String parameters are excluded.
-    """
-    for value in params.values():
-        if value is None:
-            continue
-        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
-            return params['distribution'].is_scalar_batch
-        if isinstance(value, Parameter):
-            return False
-        if not isinstance(value, (int, float, str, type(params['dtype']))):
-            return False
-    return True
-
-
-def calc_broadcast_shape_from_param(params):
-    """
-    Calculate the broadcast shape from params.
-
-    Args:
-        params (dict): parameters used to initialize distribution.
-
-    Returns:
-        tuple.
-    """
-    broadcast_shape = []
-    for value in params.values():
-        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
-            return params['distribution'].broadcast_shape
-        if isinstance(value, (str, type(params['dtype']))):
-            continue
-        if value is None:
-            return None
-        if isinstance(value, Parameter):
-            value_t = value.data
-        else:
-            value_t = cast_to_tensor(value, mstype.float32)
-        broadcast_shape = utils.get_broadcast_shape(
-            broadcast_shape, list(value_t.shape), params['name'])
-    return tuple(broadcast_shape)
-

 def check_greater_equal_zero(value, name):
     """
     Check if the given Tensor is greater than or equal to zero.
@@ -371,6 +292,9 @@ def set_param_type(args, hint_type):
     Raises:
         TypeError: if tensors in args are not the same dtype.
     """
+    int_type = mstype.int_type + mstype.uint_type
+    if hint_type in int_type:
+        hint_type = mstype.float32
     common_dtype = None
     for name, arg in args.items():
         if hasattr(arg, 'dtype'):
@@ -382,7 +306,6 @@ def set_param_type(args, hint_type):
             common_dtype = cur_dtype
         elif cur_dtype != common_dtype:
             raise TypeError(f"{name} should have the same dtype as other arguments.")
-    int_type = mstype.int_type + mstype.uint_type
    if common_dtype in int_type or common_dtype == mstype.float64:
         return mstype.float32
     return hint_type if common_dtype is None else common_dtype
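Reviewer note: hoisting the integer promotion above the loop also covers the fall-through path where `args` contributes no dtype, so `common_dtype` stays `None` and `hint_type` is returned unchanged; before this change an integer hint could leak out. A minimal pure-Python sketch (the string dtypes are hypothetical stand-ins, not mstype values):

```python
INT_TYPES = {'int8', 'int32', 'uint8'}   # hypothetical stand-ins

def set_param_type_sketch(arg_dtypes, hint_type):
    if hint_type in INT_TYPES:
        hint_type = 'float32'            # promoted before the loop
    common_dtype = None
    for cur_dtype in arg_dtypes:
        if common_dtype is None:
            common_dtype = cur_dtype
        elif cur_dtype != common_dtype:
            raise TypeError('arguments should share one dtype')
    if common_dtype in INT_TYPES or common_dtype == 'float64':
        return 'float32'
    return hint_type if common_dtype is None else common_dtype

print(set_param_type_sketch([], 'int32'))   # float32, not int32
```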
@@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type
+from ._utils.utils import check_prob, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, log_generic
@@ -116,18 +116,14 @@ class Bernoulli(Distribution):
         Constructor of Bernoulli.
         """
         param = dict(locals())
+        param['param_dict'] = {'probs': probs}
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Bernoulli, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
-        if probs is not None:
-            self._probs = cast_to_tensor(probs, self.parameter_type)
-            check_prob(self.probs)
-        else:
-            self._probs = probs
-        self.default_parameters = [self.probs]
-        self.parameter_names = ['probs1']
+        self._probs = self._add_parameter(probs, 'probs')
+        if self._probs is not None:
+            check_prob(self.probs)

         # ops needed for the class
         self.exp = exp_generic
@@ -135,14 +131,11 @@ class Bernoulli(Distribution):
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
-        self.dtypeop = P.DType()
         self.floor = P.Floor()
         self.fill = P.Fill()
         self.less = P.Less()
         self.shape = P.Shape()
         self.select = P.Select()
-        self.sq = P.Square()
-        self.sqrt = P.Sqrt()
         self.uniform = C.uniform

     def extend_repr(self):
@@ -173,9 +166,8 @@ class Bernoulli(Distribution):
             MODE(B) = 1 if probs1 > 0.5 else 0
         """
         probs1 = self._check_param_type(probs1)
-        prob_type = self.dtypeop(probs1)
-        zeros = self.fill(prob_type, self.shape(probs1), 0.0)
-        ones = self.fill(prob_type, self.shape(probs1), 1.0)
+        zeros = self.fill(self.dtype, self.shape(probs1), 0.0)
+        ones = self.fill(self.dtype, self.shape(probs1), 1.0)
         comp = self.less(0.5, probs1)
         return self.select(comp, ones, zeros)
@@ -244,13 +236,13 @@ class Bernoulli(Distribution):
         value = self.cast(value, self.parameter_type)
         value = self.floor(value)
         probs1 = self._check_param_type(probs1)
-        prob_type = self.dtypeop(probs1)
-        value = value * self.fill(prob_type, self.shape(probs1), 1.0)
-        probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
+        broadcast_shape_tensor = value * probs1
+        value = self.broadcast(value, broadcast_shape_tensor)
+        probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor)
         comp_zero = self.less(value, 0.0)
         comp_one = self.less(value, 1.0)
-        zeros = self.fill(prob_type, self.shape(value), 0.0)
-        ones = self.fill(prob_type, self.shape(value), 1.0)
+        zeros = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 0.0)
+        ones = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 1.0)
         less_than_zero = self.select(comp_zero, zeros, probs0)
         return self.select(comp_one, less_than_zero, ones)
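Reviewer note: `value * probs1` is computed purely for its shape here; multiplying two tensors yields a tensor with their common broadcast shape, which then drives explicit `broadcast_to` calls in place of the old multiply-by-ones idiom. A NumPy sketch of the shape arithmetic (illustrative only):

```python
import numpy as np

# Multiplying two arrays produces an array with their common broadcast
# shape, which can then be reused to broadcast each operand explicitly.
value = np.zeros((4, 1))
probs1 = np.full((1, 3), 0.7)
broadcast_shape_tensor = value * probs1          # shape (4, 3)
value_b = np.broadcast_to(value, broadcast_shape_tensor.shape)
probs0_b = np.broadcast_to(1.0 - probs1, broadcast_shape_tensor.shape)
print(value_b.shape, probs0_b.shape)             # (4, 3) (4, 3)
```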
@@ -14,13 +14,14 @@
 # ============================================================================
 """basic"""
 from mindspore import context
+from mindspore.ops import operations as P
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore.common import get_seed
-from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device,\
-    raise_none_error
+from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device
 from ._utils.utils import CheckTuple, CheckTensor
+from ._utils.custom_ops import broadcast_to, exp_generic, log_generic


 class Distribution(Cell):
@@ -68,14 +69,16 @@ class Distribution(Cell):
         self._seed = seed
         self._dtype = cast_type_for_device(dtype)
         self._parameters = {}

         # parsing parameters
         for k in param.keys():
             if not(k == 'self' or k.startswith('_')):
                 self._parameters[k] = param[k]

         # some attributes
-        self._broadcast_shape = calc_broadcast_shape_from_param(
-            self.parameters)
-        self._is_scalar_batch = check_scalar_from_param(self.parameters)
+        self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
+        self._broadcast_shape = self._calc_broadcast_shape()
+        self._is_scalar_batch = self._check_is_scalar_batch()

         # set the function to call according to the derived class's attributes
         self._set_prob()
@@ -91,6 +94,18 @@ class Distribution(Cell):
         self.context_mode = context.get_context('mode')
         self.checktuple = CheckTuple()
         self.checktensor = CheckTensor()
+        self.broadcast = broadcast_to
+
+        # ops needed for the base class
+        self.cast_base = P.Cast()
+        self.dtype_base = P.DType()
+        self.exp_base = exp_generic
+        self.fill_base = P.Fill()
+        self.log_base = log_generic
+        self.sametypeshape_base = P.SameTypeShape()
+        self.sq_base = P.Square()
+        self.sqrt_base = P.Sqrt()
+        self.shape_base = P.Shape()

     @property
     def name(self):
@@ -116,6 +131,21 @@ class Distribution(Cell):
     def broadcast_shape(self):
         return self._broadcast_shape

+    def _add_parameter(self, value, name):
+        """
+        Cast `value` to a tensor and add it to `self.default_parameters`.
+        Add `name` to `self.parameter_names`.
+        """
+        # initialize the attributes if they do not exist yet
+        if not hasattr(self, 'default_parameters'):
+            self.default_parameters = []
+            self.parameter_names = []
+        # cast value to a tensor if it is not None
+        value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
+        self.default_parameters += [value_t,]
+        self.parameter_names += [name,]
+        return value_t
+
     def _check_param_type(self, *args):
         """
         Check the availability and validity of default parameters and `dist_spec_args`.
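Reviewer note: `_add_parameter` lazily creates the bookkeeping lists on first use, so subclasses don't have to pre-declare them, and `None` stays in `default_parameters` as a placeholder so `_check_param_type` knows that value must come in per call. A standalone sketch of the pattern (plain Python, no MindSpore):

```python
# The first registration creates both lists; later ones append in
# matching order, and None is kept as a placeholder.
class DistSketch:
    def _add_parameter(self, value, name):
        if not hasattr(self, 'default_parameters'):
            self.default_parameters = []
            self.parameter_names = []
        self.default_parameters += [value]
        self.parameter_names += [name]
        return value

d = DistSketch()
d._add_parameter(0.5, 'probs')
d._add_parameter(None, 'extra')          # placeholder for a per-call value
print(d.parameter_names)                 # ['probs', 'extra']
print(d.default_parameters)              # [0.5, None]
```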
@@ -123,6 +153,7 @@ class Distribution(Cell):
         are None, the parameters must be passed in through `args`.
         """
         broadcast_shape = None
+        broadcast_shape_tensor = None
         common_dtype = None
         out = []
@@ -139,17 +170,17 @@ class Distribution(Cell):
             # broadcast if the number of args > 1
             if broadcast_shape is None:
-                broadcast_shape = self.shape(arg)
-                common_dtype = self.dtypeop(arg)
+                broadcast_shape = self.shape_base(arg)
+                common_dtype = self.dtype_base(arg)
+                broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
             else:
-                ones = self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
-                broadcast_shape = self.shape(arg + ones)
+                broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
+                broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
+                arg = self.broadcast(arg, broadcast_shape_tensor)

                 # check if the arguments have the same dtype
-                arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
-                dtype_tensor = self.fill(common_dtype, broadcast_shape, 1.0)
-                self.sametypeshape(arg, dtype_tensor)
-            arg = self.cast(arg, self.parameter_type)
+                self.sametypeshape_base(arg, broadcast_shape_tensor)
+            arg = self.cast_base(arg, self.parameter_type)
             out.append(arg)

         if len(out) == 1:
@@ -158,7 +189,7 @@ class Distribution(Cell):
         # broadcast all args to broadcast_shape
         result = ()
         for arg in out:
-            arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
+            arg = self.broadcast(arg, broadcast_shape_tensor)
             result = result + (arg,)
         return result
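Reviewer note: the rewritten loop folds each argument into a running broadcast shape and reuses one ones-filled tensor for the shape, the dtype check, and the final broadcast, replacing the old multiply-by-ones plus `dtype_tensor` pair. A NumPy sketch of the fold (illustrative only):

```python
import numpy as np

# Adding each argument to a ones tensor of the current broadcast shape
# grows the shape step by step.
args = [np.zeros((3, 1)), np.zeros((1, 4)), np.zeros((4,))]
broadcast_shape_tensor = None
for arg in args:
    if broadcast_shape_tensor is None:
        broadcast_shape_tensor = np.ones(arg.shape, arg.dtype)
    else:
        shape = (arg + broadcast_shape_tensor).shape
        broadcast_shape_tensor = np.ones(shape, broadcast_shape_tensor.dtype)
print(broadcast_shape_tensor.shape)  # (3, 4)
```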
@@ -171,6 +202,38 @@ class Distribution(Cell):
             return value
         return self.checktensor(value, name)

+    def _check_is_scalar_batch(self):
+        """
+        Check if the parameters used during initialization are scalars.
+        """
+        if hasattr(self, 'distribution'):
+            return self._distribution.is_scalar_batch
+        param_dict = self.parameters['param_dict']
+        for value in param_dict.values():
+            if value is None:
+                continue
+            if not isinstance(value, (int, float)):
+                return False
+        return True
+
+    def _calc_broadcast_shape(self):
+        """
+        Calculate the broadcast shape of the parameters used during initialization.
+        """
+        if hasattr(self, 'distribution'):
+            return self._distribution.broadcast_shape
+        param_dict = self.parameters['param_dict']
+        broadcast_shape_tensor = None
+        for value in param_dict.values():
+            if value is None:
+                return None
+            if broadcast_shape_tensor is None:
+                broadcast_shape_tensor = cast_to_tensor(value)
+            else:
+                value = cast_to_tensor(value)
+                broadcast_shape_tensor = (value + broadcast_shape_tensor)
+        return broadcast_shape_tensor.shape
+
     def _set_prob(self):
         """
         Set probability function based on the availability of `_prob` and `_log_likelihood`.
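Reviewer note: at construction time the parameters are concrete Python values, so the broadcast shape can be read off an actual tensor sum, and a `None` parameter short-circuits to `None` because the batch shape is then only decided per call. A NumPy sketch (illustrative stand-in for `cast_to_tensor`):

```python
import numpy as np

def calc_broadcast_shape_sketch(param_dict):
    acc = None
    for value in param_dict.values():
        if value is None:
            return None                 # batch shape unknowable up front
        value = np.asarray(value, dtype=np.float32)
        acc = value if acc is None else value + acc
    return acc.shape

print(calc_broadcast_shape_sketch({'mean': [0.0, 1.0], 'sd': 1.0}))  # (2,)
print(calc_broadcast_shape_sketch({'probs': None}))                  # None
```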
@@ -280,7 +343,7 @@ class Distribution(Cell):
         .. math::
             probability(x) = \exp(log_likelihood(x))
         """
-        return self.exp(self._log_prob(value, *args, **kwargs))
+        return self.exp_base(self._log_prob(value, *args, **kwargs))

     def prob(self, value, *args, **kwargs):
         """
@@ -304,7 +367,7 @@ class Distribution(Cell):
         .. math::
             log_prob(x) = \log(prob(x))
         """
-        return self.log(self._prob(value, *args, **kwargs))
+        return self.log_base(self._prob(value, *args, **kwargs))

     def cdf(self, value, *args, **kwargs):
         """
@@ -328,7 +391,7 @@ class Distribution(Cell):
         .. math::
             cdf(x) = \exp(log_cdf(x))
         """
-        return self.exp(self._log_cdf(value, *args, **kwargs))
+        return self.exp_base(self._log_cdf(value, *args, **kwargs))

     def _calc_cdf_from_survival(self, value, *args, **kwargs):
         r"""
@@ -346,7 +409,7 @@ class Distribution(Cell):
         .. math::
             cdf(x) = 1 - (\exp(log_survival(x)))
         """
-        return 1.0 - self.exp(self._log_survival(value, *args, **kwargs))
+        return 1.0 - self.exp_base(self._log_survival(value, *args, **kwargs))

     def log_cdf(self, value, *args, **kwargs):
         """
@@ -370,7 +433,7 @@ class Distribution(Cell):
         .. math::
             log_cdf(x) = \log(cdf(x))
         """
-        return self.log(self._call_cdf(value, *args, **kwargs))
+        return self.log_base(self._call_cdf(value, *args, **kwargs))

     def survival_function(self, value, *args, **kwargs):
         """
@@ -403,7 +466,7 @@ class Distribution(Cell):
         .. math::
             survival(x) = \exp(log_survival(x))
         """
-        return self.exp(self._log_survival(value, *args, **kwargs))
+        return self.exp_base(self._log_survival(value, *args, **kwargs))

     def log_survival(self, value, *args, **kwargs):
         """
@@ -427,7 +490,7 @@ class Distribution(Cell):
         .. math::
             log_survival(x) = \log(survival_function(x))
         """
-        return self.log(self._call_survival(value, *args, **kwargs))
+        return self.log_base(self._call_survival(value, *args, **kwargs))

     def kl_loss(self, dist, *args, **kwargs):
         """
@@ -507,7 +570,7 @@ class Distribution(Cell):
         .. math::
             STD(x) = \sqrt(VAR(x))
         """
-        return self.sqrt(self._var(*args, **kwargs))
+        return self.sqrt_base(self._var(*args, **kwargs))

     def _calc_var_from_sd(self, *args, **kwargs):
         r"""
@@ -516,7 +579,7 @@ class Distribution(Cell):
         .. math::
             VAR(x) = STD(x) ^ 2
         """
-        return self.sq(self._sd(*args, **kwargs))
+        return self.sq_base(self._sd(*args, **kwargs))

     def entropy(self, *args, **kwargs):
         """
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type
+from ._utils.utils import check_greater_zero, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, log_generic
@@ -118,18 +118,14 @@ class Exponential(Distribution):
         Constructor of Exponential.
         """
         param = dict(locals())
+        param['param_dict'] = {'rate': rate}
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Exponential, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type({'rate': rate}, self.dtype)
-        if rate is not None:
-            self._rate = cast_to_tensor(rate, self.parameter_type)
-            check_greater_zero(self._rate, "rate")
-        else:
-            self._rate = rate
-        self.default_parameters = [self.rate]
-        self.parameter_names = ['rate']
+        self._rate = self._add_parameter(rate, 'rate')
+        if self.rate is not None:
+            check_greater_zero(self.rate, 'rate')

         self.minval = np.finfo(np.float).tiny
@@ -144,8 +140,6 @@ class Exponential(Distribution):
         self.less = P.Less()
         self.select = P.Select()
         self.shape = P.Shape()
-        self.sqrt = P.Sqrt()
-        self.sq = P.Square()
         self.uniform = C.uniform

     def extend_repr(self):
@@ -18,8 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
-    set_param_type
+from ._utils.utils import check_prob, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, log_generic
@@ -120,18 +119,14 @@ class Geometric(Distribution):
         Constructor of Geometric distribution.
         """
         param = dict(locals())
+        param['param_dict'] = {'probs': probs}
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Geometric, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
-        if probs is not None:
-            self._probs = cast_to_tensor(probs, self.parameter_type)
-            check_prob(self._probs)
-        else:
-            self._probs = probs
-        self.default_parameters = [self.probs]
-        self.parameter_names = ['probs1']
+        self._probs = self._add_parameter(probs, 'probs')
+        if self._probs is not None:
+            check_prob(self.probs)

         self.minval = np.finfo(np.float).tiny
@@ -150,7 +145,6 @@ class Geometric(Distribution):
         self.select = P.Select()
         self.shape = P.Shape()
         self.sq = P.Square()
-        self.sqrt = P.Sqrt()
         self.uniform = C.uniform

     def extend_repr(self):
@@ -181,7 +175,7 @@ class Geometric(Distribution):
             MODE(Geo) = 0
         """
         probs1 = self._check_param_type(probs1)
-        return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
+        return self.fill(self.dtype, self.shape(probs1), 0.)

     def _var(self, probs1=None):
         r"""
@@ -229,7 +223,7 @@ class Geometric(Distribution):
         value = self.floor(value)
         probs1 = self._check_param_type(probs1)
         pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
-        zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
+        zeros = self.fill(self.dtypeop(pmf), self.shape(pmf), 0.0)
         comp = self.less(value, zeros)
         return self.select(comp, zeros, pmf)
@@ -252,7 +246,7 @@ class Geometric(Distribution):
         probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         cdf = 1.0 - self.pow(probs0, value + 1.0)
-        zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
+        zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
         comp = self.less(value, zeros)
         return self.select(comp, zeros, cdf)
@@ -18,8 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
-    set_param_type
+from ._utils.utils import check_greater_zero, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, expm1_generic, log_generic
@@ -125,23 +124,15 @@ class Normal(Distribution):
         Constructor of Normal.
         """
         param = dict(locals())
+        param['param_dict'] = {'mean': mean, 'sd': sd}
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Normal, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type(
-            {'mean': mean, 'sd': sd}, self.dtype)
-        if mean is not None and sd is not None:
-            self._mean_value = cast_to_tensor(mean, self.parameter_type)
-            self._sd_value = cast_to_tensor(sd, self.parameter_type)
-            check_greater_zero(self._sd_value, "Standard deviation")
-        else:
-            self._mean_value = mean if mean is None else cast_to_tensor(
-                mean, self.parameter_type)
-            self._sd_value = sd if sd is None else cast_to_tensor(
-                sd, self.parameter_type)
-        self.default_parameters = [self._mean_value, self._sd_value]
-        self.parameter_names = ['mean', 'sd']
+        self._mean_value = self._add_parameter(mean, 'mean')
+        self._sd_value = self._add_parameter(sd, 'sd')
+        if self._sd_value is not None:
+            check_greater_zero(self._sd_value, "Standard deviation")

         # ops needed for the class
         self.exp = exp_generic
@@ -151,13 +142,9 @@ class Normal(Distribution):
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
-        self.fill = P.Fill()
         self.shape = P.Shape()
         self.sq = P.Square()
         self.sqrt = P.Sqrt()
-        self.zeroslike = P.ZerosLike()
-        self.dtypeop = P.DType()
-        self.sametypeshape = P.SameTypeShape()

     def extend_repr(self):
         if self.is_scalar_batch:
@@ -81,6 +81,8 @@ class TransformedDistribution(Distribution):
         self._bijector = bijector
         self._distribution = distribution
         self._is_linear_transformation = bijector.is_constant_jacobian
+        self.default_parameters = distribution.default_parameters
+        self.parameter_names = distribution.parameter_names
         self.exp = exp_generic
         self.log = log_generic
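Reviewer note: the wrapper aliases the wrapped distribution's bookkeeping rather than rebuilding it, so the inherited `_check_param_type` resolves defaults against the base distribution. Worth noting this binds the same list objects, so later mutation on either side is visible on both; a plain-Python sketch:

```python
# Both objects share the same list instances, so the wrapper always
# sees the base distribution's defaults.
class BaseDist:
    def __init__(self):
        self.default_parameters = [0.0, 1.0]
        self.parameter_names = ['low', 'high']

class Wrapper:
    def __init__(self, distribution):
        self.default_parameters = distribution.default_parameters
        self.parameter_names = distribution.parameter_names

base = BaseDist()
wrapped = Wrapper(base)
print(wrapped.default_parameters is base.default_parameters)  # True
```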
@@ -17,8 +17,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
-    set_param_type
+from ._utils.utils import check_greater, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, log_generic
@@ -124,23 +123,16 @@ class Uniform(Distribution):
         Constructor of Uniform distribution.
         """
         param = dict(locals())
+        param['param_dict'] = {'low': low, 'high': high}
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Uniform, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type(
-            {'low': low, 'high': high}, self.dtype)
-        if low is not None and high is not None:
-            self._low = cast_to_tensor(low, self.parameter_type)
-            self._high = cast_to_tensor(high, self.parameter_type)
-            check_greater(self.low, self.high, "low value", "high value")
-        else:
-            self._low = low if low is None else cast_to_tensor(
-                low, self.parameter_type)
-            self._high = high if high is None else cast_to_tensor(
-                high, self.parameter_type)
-        self.default_parameters = [self.low, self.high]
-        self.parameter_names = ['low', 'high']
+        self._low = self._add_parameter(low, 'low')
+        self._high = self._add_parameter(high, 'high')
+        if self.low is not None and self.high is not None:
+            check_greater(self.low, self.high, 'low', 'high')

         # ops needed for the class
         self.exp = exp_generic
@@ -156,12 +148,9 @@ class Uniform(Distribution):
         self.select = P.Select()
         self.shape = P.Shape()
         self.sq = P.Square()
-        self.sqrt = P.Sqrt()
         self.zeroslike = P.ZerosLike()
         self.uniform = C.uniform
-        self.sametypeshape = P.SameTypeShape()

     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'low = {self.low}, high = {self.high}'
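Reviewer note: end to end, the refactor keeps the user-facing `dist_spec_args` contract. A hedged usage sketch (assuming the standard `mindspore.nn.probability.distribution` API; values and shapes are illustrative):

```python
import mindspore.nn.probability.distribution as msd
import mindspore.common.dtype as mstype
from mindspore import Tensor

# Parameters passed at construction are registered via _add_parameter
# and become the stored defaults ...
u1 = msd.Uniform(0.0, 1.0, dtype=mstype.float32)
value = Tensor([0.2, 0.8], mstype.float32)
print(u1.prob(value))              # resolves 'low'/'high' from defaults

# ... otherwise they must accompany each call, and _check_param_type
# broadcasts and casts them to parameter_type.
u2 = msd.Uniform(dtype=mstype.float32)
low = Tensor(0.0, mstype.float32)
high = Tensor(1.0, mstype.float32)
print(u2.prob(value, low, high))
```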