| @@ -19,17 +19,18 @@ from .utils import * | |||
| from .custom_ops import * | |||
| __all__ = [ | |||
| 'convert_to_batch', | |||
| 'cast_to_tensor', | |||
| 'check_greater', | |||
| 'check_greater_equal_zero', | |||
| 'check_greater_zero', | |||
| 'calc_broadcast_shape_from_param', | |||
| 'check_scalar_from_param', | |||
| 'check_prob', | |||
| 'check_type', | |||
| 'exp_generic', | |||
| 'expm1_generic', | |||
| 'log_generic', | |||
| 'log1p_generic', | |||
| 'broadcast_to', | |||
| 'set_param_type', | |||
| 'CheckTensor', | |||
| 'CheckTuple', | |||
| ] | |||
| @@ -72,3 +72,12 @@ def log1p_generic(x): | |||
| Log1p ops on GPU device or when device_target == GPU. | |||
| """ | |||
| return log_generic(x + 1.0) | |||
| def broadcast_to(x, target): | |||
| """ | |||
| Broadcast x to the shape of target. | |||
| """ | |||
| shape = P.Shape() | |||
| if shape(x) == shape(target): | |||
| return x | |||
| return P.BroadcastTo(shape(target))(x) | |||
| @@ -19,13 +19,10 @@ from mindspore._checkparam import Validator as validator | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.ops import _utils as utils | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register | |||
| import mindspore.nn as nn | |||
| import mindspore.nn.probability as msp | |||
| def cast_to_tensor(t, hint_type=mstype.float32): | |||
| """ | |||
| @@ -46,41 +43,13 @@ def cast_to_tensor(t, hint_type=mstype.float32): | |||
| raise ValueError(f'Input cannot be None in cast_to_tensor') | |||
| if isinstance(t, Parameter): | |||
| return t | |||
| t_type = hint_type | |||
| if isinstance(t, Tensor): | |||
| # convert the type of tensor to dtype | |||
| return Tensor(t.asnumpy(), dtype=t_type) | |||
| if isinstance(t, (list, np.ndarray)): | |||
| return Tensor(t, dtype=t_type) | |||
| if isinstance(t, bool): | |||
| raise TypeError(f'Input cannot be Type Bool') | |||
| if isinstance(t, (int, float)): | |||
| return Tensor(t, dtype=t_type) | |||
| if isinstance(t, (Tensor, np.ndarray, list, int, float)): | |||
| return Tensor(t, dtype=hint_type) | |||
| invalid_type = type(t) | |||
| raise TypeError( | |||
| f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}") | |||
| def convert_to_batch(t, batch_shape, required_type): | |||
| """ | |||
| Convert a Tensor to a given batch shape. | |||
| Args: | |||
| t (int, float, list, numpy.ndarray, Tensor, Parameter): Tensor to be converted. | |||
| batch_shape (tuple): desired batch shape. | |||
| dtype (mindspore.dtype): desired dtype. | |||
| Raises: | |||
| RuntimeError: if the converison cannot be done. | |||
| Returns: | |||
| Tensor, with shape of batch_shape. | |||
| """ | |||
| if isinstance(t, Parameter): | |||
| return t | |||
| t = cast_to_tensor(t, required_type) | |||
| return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type) | |||
| f"Unable to convert input of type {invalid_type} to a Tensor of type {hint_type}") | |||
| def cast_type_for_device(dtype): | |||
| """ | |||
| @@ -100,54 +69,6 @@ def cast_type_for_device(dtype): | |||
| return dtype | |||
| def check_scalar_from_param(params): | |||
| """ | |||
| Check if params are all scalars. | |||
| Args: | |||
| params (dict): parameters used to initialize distribution. | |||
| Notes: String parameters are excluded. | |||
| """ | |||
| for value in params.values(): | |||
| if value is None: | |||
| continue | |||
| if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)): | |||
| return params['distribution'].is_scalar_batch | |||
| if isinstance(value, Parameter): | |||
| return False | |||
| if not isinstance(value, (int, float, str, type(params['dtype']))): | |||
| return False | |||
| return True | |||
| def calc_broadcast_shape_from_param(params): | |||
| """ | |||
| Calculate the broadcast shape from params. | |||
| Args: | |||
| params (dict): parameters used to initialize distribution. | |||
| Returns: | |||
| tuple. | |||
| """ | |||
| broadcast_shape = [] | |||
| for value in params.values(): | |||
| if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)): | |||
| return params['distribution'].broadcast_shape | |||
| if isinstance(value, (str, type(params['dtype']))): | |||
| continue | |||
| if value is None: | |||
| return None | |||
| if isinstance(value, Parameter): | |||
| value_t = value.data | |||
| else: | |||
| value_t = cast_to_tensor(value, mstype.float32) | |||
| broadcast_shape = utils.get_broadcast_shape( | |||
| broadcast_shape, list(value_t.shape), params['name']) | |||
| return tuple(broadcast_shape) | |||
| def check_greater_equal_zero(value, name): | |||
| """ | |||
| Check if the given Tensor is greater zero. | |||
| @@ -371,6 +292,9 @@ def set_param_type(args, hint_type): | |||
| Raises: | |||
| TypeError: if tensors in args are not the same dtype. | |||
| """ | |||
| int_type = mstype.int_type + mstype.uint_type | |||
| if hint_type in int_type: | |||
| hint_type = mstype.float32 | |||
| common_dtype = None | |||
| for name, arg in args.items(): | |||
| if hasattr(arg, 'dtype'): | |||
| @@ -382,7 +306,6 @@ def set_param_type(args, hint_type): | |||
| common_dtype = cur_dtype | |||
| elif cur_dtype != common_dtype: | |||
| raise TypeError(f"{name} should have the same dtype as other arguments.") | |||
| int_type = mstype.int_type + mstype.uint_type | |||
| if common_dtype in int_type or common_dtype == mstype.float64: | |||
| return mstype.float32 | |||
| return hint_type if common_dtype is None else common_dtype | |||
| @@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type | |||
| from ._utils.utils import check_prob, check_type, check_distribution_name | |||
| from ._utils.custom_ops import exp_generic, log_generic | |||
| @@ -116,18 +116,14 @@ class Bernoulli(Distribution): | |||
| Constructor of Bernoulli. | |||
| """ | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'probs': probs} | |||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Bernoulli, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = set_param_type({'probs1': probs}, mstype.float32) | |||
| if probs is not None: | |||
| self._probs = cast_to_tensor(probs, self.parameter_type) | |||
| check_prob(self.probs) | |||
| else: | |||
| self._probs = probs | |||
| self.default_parameters = [self.probs] | |||
| self.parameter_names = ['probs1'] | |||
| self._probs = self._add_parameter(probs, 'probs') | |||
| if self._probs is not None: | |||
| check_prob(self.probs) | |||
| # ops needed for the class | |||
| self.exp = exp_generic | |||
| @@ -135,14 +131,11 @@ class Bernoulli(Distribution): | |||
| self.squeeze = P.Squeeze(0) | |||
| self.cast = P.Cast() | |||
| self.const = P.ScalarToArray() | |||
| self.dtypeop = P.DType() | |||
| self.floor = P.Floor() | |||
| self.fill = P.Fill() | |||
| self.less = P.Less() | |||
| self.shape = P.Shape() | |||
| self.select = P.Select() | |||
| self.sq = P.Square() | |||
| self.sqrt = P.Sqrt() | |||
| self.uniform = C.uniform | |||
| def extend_repr(self): | |||
| @@ -173,9 +166,8 @@ class Bernoulli(Distribution): | |||
| MODE(B) = 1 if probs1 > 0.5 else = 0 | |||
| """ | |||
| probs1 = self._check_param_type(probs1) | |||
| prob_type = self.dtypeop(probs1) | |||
| zeros = self.fill(prob_type, self.shape(probs1), 0.0) | |||
| ones = self.fill(prob_type, self.shape(probs1), 1.0) | |||
| zeros = self.fill(self.dtype, self.shape(probs1), 0.0) | |||
| ones = self.fill(self.dtype, self.shape(probs1), 1.0) | |||
| comp = self.less(0.5, probs1) | |||
| return self.select(comp, ones, zeros) | |||
| @@ -244,13 +236,13 @@ class Bernoulli(Distribution): | |||
| value = self.cast(value, self.parameter_type) | |||
| value = self.floor(value) | |||
| probs1 = self._check_param_type(probs1) | |||
| prob_type = self.dtypeop(probs1) | |||
| value = value * self.fill(prob_type, self.shape(probs1), 1.0) | |||
| probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0) | |||
| broadcast_shape_tensor = value * probs1 | |||
| value = self.broadcast(value, broadcast_shape_tensor) | |||
| probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor) | |||
| comp_zero = self.less(value, 0.0) | |||
| comp_one = self.less(value, 1.0) | |||
| zeros = self.fill(prob_type, self.shape(value), 0.0) | |||
| ones = self.fill(prob_type, self.shape(value), 1.0) | |||
| zeros = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 0.0) | |||
| ones = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 1.0) | |||
| less_than_zero = self.select(comp_zero, zeros, probs0) | |||
| return self.select(comp_one, less_than_zero, ones) | |||
| @@ -14,13 +14,14 @@ | |||
| # ============================================================================ | |||
| """basic""" | |||
| from mindspore import context | |||
| from mindspore.ops import operations as P | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._checkparam import Rel | |||
| from mindspore.common import get_seed | |||
| from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device,\ | |||
| raise_none_error | |||
| from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device | |||
| from ._utils.utils import CheckTuple, CheckTensor | |||
| from ._utils.custom_ops import broadcast_to, exp_generic, log_generic | |||
| class Distribution(Cell): | |||
| @@ -68,14 +69,16 @@ class Distribution(Cell): | |||
| self._seed = seed | |||
| self._dtype = cast_type_for_device(dtype) | |||
| self._parameters = {} | |||
| # parsing parameters | |||
| for k in param.keys(): | |||
| if not(k == 'self' or k.startswith('_')): | |||
| self._parameters[k] = param[k] | |||
| # some attributes | |||
| self._broadcast_shape = calc_broadcast_shape_from_param( | |||
| self.parameters) | |||
| self._is_scalar_batch = check_scalar_from_param(self.parameters) | |||
| self.parameter_type = set_param_type(self.parameters['param_dict'], dtype) | |||
| self._broadcast_shape = self._calc_broadcast_shape() | |||
| self._is_scalar_batch = self._check_is_scalar_batch() | |||
| # set the function to call according to the derived class's attributes | |||
| self._set_prob() | |||
| @@ -91,6 +94,18 @@ class Distribution(Cell): | |||
| self.context_mode = context.get_context('mode') | |||
| self.checktuple = CheckTuple() | |||
| self.checktensor = CheckTensor() | |||
| self.broadcast = broadcast_to | |||
| # ops needed for the base class | |||
| self.cast_base = P.Cast() | |||
| self.dtype_base = P.DType() | |||
| self.exp_base = exp_generic | |||
| self.fill_base = P.Fill() | |||
| self.log_base = log_generic | |||
| self.sametypeshape_base = P.SameTypeShape() | |||
| self.sq_base = P.Square() | |||
| self.sqrt_base = P.Sqrt() | |||
| self.shape_base = P.Shape() | |||
| @property | |||
| def name(self): | |||
| @@ -116,6 +131,21 @@ class Distribution(Cell): | |||
| def broadcast_shape(self): | |||
| return self._broadcast_shape | |||
| def _add_parameter(self, value, name): | |||
| """ | |||
| Cast `value` to a tensor and add it to `self.default_parameters`. | |||
| Add `name` into and `self.parameter_names`. | |||
| """ | |||
| # initialize the attributes if they do not exist yet | |||
| if not hasattr(self, 'default_parameters'): | |||
| self.default_parameters = [] | |||
| self.parameter_names = [] | |||
| # cast value to a tensor if it is not None | |||
| value_t = None if value is None else cast_to_tensor(value, self.parameter_type) | |||
| self.default_parameters += [value_t,] | |||
| self.parameter_names += [name,] | |||
| return value_t | |||
| def _check_param_type(self, *args): | |||
| """ | |||
| Check the availability and validity of default parameters and `dist_spec_args`. | |||
| @@ -123,6 +153,7 @@ class Distribution(Cell): | |||
| are None, the parameters must be passed in through `args`. | |||
| """ | |||
| broadcast_shape = None | |||
| broadcast_shape_tensor = None | |||
| common_dtype = None | |||
| out = [] | |||
| @@ -139,17 +170,17 @@ class Distribution(Cell): | |||
| # broadcast if the number of args > 1 | |||
| if broadcast_shape is None: | |||
| broadcast_shape = self.shape(arg) | |||
| common_dtype = self.dtypeop(arg) | |||
| broadcast_shape = self.shape_base(arg) | |||
| common_dtype = self.dtype_base(arg) | |||
| broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0) | |||
| else: | |||
| ones = self.fill(self.dtypeop(arg), broadcast_shape, 1.0) | |||
| broadcast_shape = self.shape(arg + ones) | |||
| broadcast_shape = self.shape_base(arg + broadcast_shape_tensor) | |||
| broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0) | |||
| arg = self.broadcast(arg, broadcast_shape_tensor) | |||
| # check if the arguments have the same dtype | |||
| arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0) | |||
| dtype_tensor = self.fill(common_dtype, broadcast_shape, 1.0) | |||
| self.sametypeshape(arg, dtype_tensor) | |||
| arg = self.cast(arg, self.parameter_type) | |||
| self.sametypeshape_base(arg, broadcast_shape_tensor) | |||
| arg = self.cast_base(arg, self.parameter_type) | |||
| out.append(arg) | |||
| if len(out) == 1: | |||
| @@ -158,7 +189,7 @@ class Distribution(Cell): | |||
| # broadcast all args to broadcast_shape | |||
| result = () | |||
| for arg in out: | |||
| arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0) | |||
| arg = self.broadcast(arg, broadcast_shape_tensor) | |||
| result = result + (arg,) | |||
| return result | |||
| @@ -171,6 +202,38 @@ class Distribution(Cell): | |||
| return value | |||
| return self.checktensor(value, name) | |||
| def _check_is_scalar_batch(self): | |||
| """ | |||
| Check if the parameters used during initialization are scalars. | |||
| """ | |||
| if hasattr(self, 'distribution'): | |||
| return self._distribution.is_scalar_batch | |||
| param_dict = self.parameters['param_dict'] | |||
| for value in param_dict.values(): | |||
| if value is None: | |||
| continue | |||
| if not isinstance(value, (int, float)): | |||
| return False | |||
| return True | |||
| def _calc_broadcast_shape(self): | |||
| """ | |||
| Calculate the broadcast shape of the parameters used during initialization. | |||
| """ | |||
| if hasattr(self, 'distribution'): | |||
| return self._distribution.broadcast_shape | |||
| param_dict = self.parameters['param_dict'] | |||
| broadcast_shape_tensor = None | |||
| for value in param_dict.values(): | |||
| if value is None: | |||
| return None | |||
| if broadcast_shape_tensor is None: | |||
| broadcast_shape_tensor = cast_to_tensor(value) | |||
| else: | |||
| value = cast_to_tensor(value) | |||
| broadcast_shape_tensor = (value + broadcast_shape_tensor) | |||
| return broadcast_shape_tensor.shape | |||
| def _set_prob(self): | |||
| """ | |||
| Set probability funtion based on the availability of `_prob` and `_log_likehood`. | |||
| @@ -280,7 +343,7 @@ class Distribution(Cell): | |||
| .. math:: | |||
| probability(x) = \exp(log_likehood(x)) | |||
| """ | |||
| return self.exp(self._log_prob(value, *args, **kwargs)) | |||
| return self.exp_base(self._log_prob(value, *args, **kwargs)) | |||
| def prob(self, value, *args, **kwargs): | |||
| """ | |||
| @@ -304,7 +367,7 @@ class Distribution(Cell): | |||
| .. math:: | |||
| log_prob(x) = \log(prob(x)) | |||
| """ | |||
| return self.log(self._prob(value, *args, **kwargs)) | |||
| return self.log_base(self._prob(value, *args, **kwargs)) | |||
| def cdf(self, value, *args, **kwargs): | |||
| """ | |||
| @@ -328,7 +391,7 @@ class Distribution(Cell): | |||
| .. math:: | |||
| cdf(x) = \exp(log_cdf(x)) | |||
| """ | |||
| return self.exp(self._log_cdf(value, *args, **kwargs)) | |||
| return self.exp_base(self._log_cdf(value, *args, **kwargs)) | |||
| def _calc_cdf_from_survival(self, value, *args, **kwargs): | |||
| r""" | |||
| @@ -346,7 +409,7 @@ class Distribution(Cell): | |||
| .. math:: | |||
| cdf(x) = 1 - (\exp(log_survival(x))) | |||
| """ | |||
| return 1.0 - self.exp(self._log_survival(value, *args, **kwargs)) | |||
| return 1.0 - self.exp_base(self._log_survival(value, *args, **kwargs)) | |||
| def log_cdf(self, value, *args, **kwargs): | |||
| """ | |||
| @@ -370,7 +433,7 @@ class Distribution(Cell): | |||
| .. math:: | |||
| log_cdf(x) = \log(cdf(x)) | |||
| """ | |||
| return self.log(self._call_cdf(value, *args, **kwargs)) | |||
| return self.log_base(self._call_cdf(value, *args, **kwargs)) | |||
| def survival_function(self, value, *args, **kwargs): | |||
| """ | |||
| @@ -403,7 +466,7 @@ class Distribution(Cell): | |||
| .. math:: | |||
| survival(x) = \exp(survival_function(x)) | |||
| """ | |||
| return self.exp(self._log_survival(value, *args, **kwargs)) | |||
| return self.exp_base(self._log_survival(value, *args, **kwargs)) | |||
| def log_survival(self, value, *args, **kwargs): | |||
| """ | |||
| @@ -427,7 +490,7 @@ class Distribution(Cell): | |||
| .. math:: | |||
| log_survival(x) = \log(survival_function(x)) | |||
| """ | |||
| return self.log(self._call_survival(value, *args, **kwargs)) | |||
| return self.log_base(self._call_survival(value, *args, **kwargs)) | |||
| def kl_loss(self, dist, *args, **kwargs): | |||
| """ | |||
| @@ -507,7 +570,7 @@ class Distribution(Cell): | |||
| .. math:: | |||
| STD(x) = \sqrt(VAR(x)) | |||
| """ | |||
| return self.sqrt(self._var(*args, **kwargs)) | |||
| return self.sqrt_base(self._var(*args, **kwargs)) | |||
| def _calc_var_from_sd(self, *args, **kwargs): | |||
| r""" | |||
| @@ -516,7 +579,7 @@ class Distribution(Cell): | |||
| .. math:: | |||
| VAR(x) = STD(x) ^ 2 | |||
| """ | |||
| return self.sq(self._sd(*args, **kwargs)) | |||
| return self.sq_base(self._sd(*args, **kwargs)) | |||
| def entropy(self, *args, **kwargs): | |||
| """ | |||
| @@ -18,7 +18,7 @@ from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type | |||
| from ._utils.utils import check_greater_zero, check_type, check_distribution_name | |||
| from ._utils.custom_ops import exp_generic, log_generic | |||
| @@ -118,18 +118,14 @@ class Exponential(Distribution): | |||
| Constructor of Exponential. | |||
| """ | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'rate': rate} | |||
| valid_dtype = mstype.float_type | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Exponential, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = set_param_type({'rate': rate}, self.dtype) | |||
| if rate is not None: | |||
| self._rate = cast_to_tensor(rate, self.parameter_type) | |||
| check_greater_zero(self._rate, "rate") | |||
| else: | |||
| self._rate = rate | |||
| self.default_parameters = [self.rate] | |||
| self.parameter_names = ['rate'] | |||
| self._rate = self._add_parameter(rate, 'rate') | |||
| if self.rate is not None: | |||
| check_greater_zero(self.rate, 'rate') | |||
| self.minval = np.finfo(np.float).tiny | |||
| @@ -144,8 +140,6 @@ class Exponential(Distribution): | |||
| self.less = P.Less() | |||
| self.select = P.Select() | |||
| self.shape = P.Shape() | |||
| self.sqrt = P.Sqrt() | |||
| self.sq = P.Square() | |||
| self.uniform = C.uniform | |||
| def extend_repr(self): | |||
| @@ -18,8 +18,7 @@ from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ | |||
| set_param_type | |||
| from ._utils.utils import check_prob, check_type, check_distribution_name | |||
| from ._utils.custom_ops import exp_generic, log_generic | |||
| @@ -120,18 +119,14 @@ class Geometric(Distribution): | |||
| Constructor of Geometric distribution. | |||
| """ | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'probs': probs} | |||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Geometric, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = set_param_type({'probs1': probs}, mstype.float32) | |||
| if probs is not None: | |||
| self._probs = cast_to_tensor(probs, self.parameter_type) | |||
| check_prob(self._probs) | |||
| else: | |||
| self._probs = probs | |||
| self.default_parameters = [self.probs] | |||
| self.parameter_names = ['probs1'] | |||
| self._probs = self._add_parameter(probs, 'probs') | |||
| if self._probs is not None: | |||
| check_prob(self.probs) | |||
| self.minval = np.finfo(np.float).tiny | |||
| @@ -150,7 +145,6 @@ class Geometric(Distribution): | |||
| self.select = P.Select() | |||
| self.shape = P.Shape() | |||
| self.sq = P.Square() | |||
| self.sqrt = P.Sqrt() | |||
| self.uniform = C.uniform | |||
| def extend_repr(self): | |||
| @@ -181,7 +175,7 @@ class Geometric(Distribution): | |||
| MODE(Geo) = 0 | |||
| """ | |||
| probs1 = self._check_param_type(probs1) | |||
| return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.) | |||
| return self.fill(self.dtype, self.shape(probs1), 0.) | |||
| def _var(self, probs1=None): | |||
| r""" | |||
| @@ -229,7 +223,7 @@ class Geometric(Distribution): | |||
| value = self.floor(value) | |||
| probs1 = self._check_param_type(probs1) | |||
| pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) | |||
| zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) | |||
| zeros = self.fill(self.dtypeop(pmf), self.shape(pmf), 0.0) | |||
| comp = self.less(value, zeros) | |||
| return self.select(comp, zeros, pmf) | |||
| @@ -252,7 +246,7 @@ class Geometric(Distribution): | |||
| probs1 = self._check_param_type(probs1) | |||
| probs0 = 1.0 - probs1 | |||
| cdf = 1.0 - self.pow(probs0, value + 1.0) | |||
| zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0) | |||
| zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) | |||
| comp = self.less(value, zeros) | |||
| return self.select(comp, zeros, cdf) | |||
| @@ -18,8 +18,7 @@ from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ | |||
| set_param_type | |||
| from ._utils.utils import check_greater_zero, check_type, check_distribution_name | |||
| from ._utils.custom_ops import exp_generic, expm1_generic, log_generic | |||
| @@ -125,23 +124,15 @@ class Normal(Distribution): | |||
| Constructor of Normal. | |||
| """ | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'mean': mean, 'sd': sd} | |||
| valid_dtype = mstype.float_type | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Normal, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = set_param_type( | |||
| {'mean': mean, 'sd': sd}, self.dtype) | |||
| if mean is not None and sd is not None: | |||
| self._mean_value = cast_to_tensor(mean, self.parameter_type) | |||
| self._sd_value = cast_to_tensor(sd, self.parameter_type) | |||
| check_greater_zero(self._sd_value, "Standard deviation") | |||
| else: | |||
| self._mean_value = mean if mean is None else cast_to_tensor( | |||
| mean, self.parameter_type) | |||
| self._sd_value = sd if sd is None else cast_to_tensor( | |||
| sd, self.parameter_type) | |||
| self.default_parameters = [self._mean_value, self._sd_value] | |||
| self.parameter_names = ['mean', 'sd'] | |||
| self._mean_value = self._add_parameter(mean, 'mean') | |||
| self._sd_value = self._add_parameter(sd, 'sd') | |||
| if self._sd_value is not None: | |||
| check_greater_zero(self._sd_value, "Standard deviation") | |||
| # ops needed for the class | |||
| self.exp = exp_generic | |||
| @@ -151,13 +142,9 @@ class Normal(Distribution): | |||
| self.squeeze = P.Squeeze(0) | |||
| self.cast = P.Cast() | |||
| self.const = P.ScalarToArray() | |||
| self.fill = P.Fill() | |||
| self.shape = P.Shape() | |||
| self.sq = P.Square() | |||
| self.sqrt = P.Sqrt() | |||
| self.zeroslike = P.ZerosLike() | |||
| self.dtypeop = P.DType() | |||
| self.sametypeshape = P.SameTypeShape() | |||
| def extend_repr(self): | |||
| if self.is_scalar_batch: | |||
| @@ -81,6 +81,8 @@ class TransformedDistribution(Distribution): | |||
| self._bijector = bijector | |||
| self._distribution = distribution | |||
| self._is_linear_transformation = bijector.is_constant_jacobian | |||
| self.default_parameters = distribution.default_parameters | |||
| self.parameter_names = distribution.parameter_names | |||
| self.exp = exp_generic | |||
| self.log = log_generic | |||
| @@ -17,8 +17,7 @@ from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\ | |||
| set_param_type | |||
| from ._utils.utils import check_greater, check_type, check_distribution_name | |||
| from ._utils.custom_ops import exp_generic, log_generic | |||
| @@ -124,23 +123,16 @@ class Uniform(Distribution): | |||
| Constructor of Uniform distribution. | |||
| """ | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'low': low, 'high': high} | |||
| valid_dtype = mstype.float_type | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Uniform, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = set_param_type( | |||
| {'low': low, 'high': high}, self.dtype) | |||
| if low is not None and high is not None: | |||
| self._low = cast_to_tensor(low, self.parameter_type) | |||
| self._high = cast_to_tensor(high, self.parameter_type) | |||
| check_greater(self.low, self.high, "low value", "high value") | |||
| else: | |||
| self._low = low if low is None else cast_to_tensor( | |||
| low, self.parameter_type) | |||
| self._high = high if high is None else cast_to_tensor( | |||
| high, self.parameter_type) | |||
| self.default_parameters = [self.low, self.high] | |||
| self.parameter_names = ['low', 'high'] | |||
| self._low = self._add_parameter(low, 'low') | |||
| self._high = self._add_parameter(high, 'high') | |||
| if self.low is not None and self.high is not None: | |||
| check_greater(self.low, self.high, 'low', 'high') | |||
| # ops needed for the class | |||
| self.exp = exp_generic | |||
| @@ -156,12 +148,9 @@ class Uniform(Distribution): | |||
| self.select = P.Select() | |||
| self.shape = P.Shape() | |||
| self.sq = P.Square() | |||
| self.sqrt = P.Sqrt() | |||
| self.zeroslike = P.ZerosLike() | |||
| self.uniform = C.uniform | |||
| self.sametypeshape = P.SameTypeShape() | |||
| def extend_repr(self): | |||
| if self.is_scalar_batch: | |||
| str_info = f'low = {self.low}, high = {self.high}' | |||