diff --git a/mindspore/nn/probability/distribution/_utils/__init__.py b/mindspore/nn/probability/distribution/_utils/__init__.py
index 4d2b29c064..49150d99a6 100644
--- a/mindspore/nn/probability/distribution/_utils/__init__.py
+++ b/mindspore/nn/probability/distribution/_utils/__init__.py
@@ -19,17 +19,18 @@ from .utils import *
 from .custom_ops import *

 __all__ = [
-    'convert_to_batch',
     'cast_to_tensor',
     'check_greater',
     'check_greater_equal_zero',
     'check_greater_zero',
-    'calc_broadcast_shape_from_param',
-    'check_scalar_from_param',
     'check_prob',
     'check_type',
     'exp_generic',
     'expm1_generic',
     'log_generic',
     'log1p_generic',
+    'broadcast_to',
+    'set_param_type',
+    'CheckTensor',
+    'CheckTuple',
 ]
diff --git a/mindspore/nn/probability/distribution/_utils/custom_ops.py b/mindspore/nn/probability/distribution/_utils/custom_ops.py
index bda3ae3eaa..3bc7c3e0fd 100644
--- a/mindspore/nn/probability/distribution/_utils/custom_ops.py
+++ b/mindspore/nn/probability/distribution/_utils/custom_ops.py
@@ -72,3 +72,12 @@ def log1p_generic(x):
     Log1p ops on GPU device or when device_target == GPU.
     """
     return log_generic(x + 1.0)
+
+def broadcast_to(x, target):
+    """
+    Broadcast x to the shape of target.
+    """
+    shape = P.Shape()
+    if shape(x) == shape(target):
+        return x
+    return P.BroadcastTo(shape(target))(x)
diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py
index 62ac681a27..8484fdeddc 100644
--- a/mindspore/nn/probability/distribution/_utils/utils.py
+++ b/mindspore/nn/probability/distribution/_utils/utils.py
@@ -19,13 +19,10 @@ from mindspore._checkparam import Validator as validator
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
-from mindspore.ops import _utils as utils
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
 import mindspore.nn as nn
-import mindspore.nn.probability as msp
-

 def cast_to_tensor(t, hint_type=mstype.float32):
     """
@@ -46,41 +43,13 @@
         raise ValueError(f'Input cannot be None in cast_to_tensor')
     if isinstance(t, Parameter):
         return t
-    t_type = hint_type
-    if isinstance(t, Tensor):
-        # convert the type of tensor to dtype
-        return Tensor(t.asnumpy(), dtype=t_type)
-    if isinstance(t, (list, np.ndarray)):
-        return Tensor(t, dtype=t_type)
     if isinstance(t, bool):
         raise TypeError(f'Input cannot be Type Bool')
-    if isinstance(t, (int, float)):
-        return Tensor(t, dtype=t_type)
+    if isinstance(t, (Tensor, np.ndarray, list, int, float)):
+        return Tensor(t, dtype=hint_type)
     invalid_type = type(t)
     raise TypeError(
-        f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")
-
-
-def convert_to_batch(t, batch_shape, required_type):
-    """
-    Convert a Tensor to a given batch shape.
-
-    Args:
-        t (int, float, list, numpy.ndarray, Tensor, Parameter): Tensor to be converted.
-        batch_shape (tuple): desired batch shape.
-        dtype (mindspore.dtype): desired dtype.
-
-    Raises:
-        RuntimeError: if the converison cannot be done.
-
-    Returns:
-        Tensor, with shape of batch_shape.
-    """
-    if isinstance(t, Parameter):
-        return t
-    t = cast_to_tensor(t, required_type)
-    return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)
-
+        f"Unable to convert input of type {invalid_type} to a Tensor of type {hint_type}")

 def cast_type_for_device(dtype):
     """
@@ -100,54 +69,6 @@
     return dtype


-def check_scalar_from_param(params):
-    """
-    Check if params are all scalars.
-
-    Args:
-        params (dict): parameters used to initialize distribution.
-
-    Notes: String parameters are excluded.
-    """
-    for value in params.values():
-        if value is None:
-            continue
-        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
-            return params['distribution'].is_scalar_batch
-        if isinstance(value, Parameter):
-            return False
-        if not isinstance(value, (int, float, str, type(params['dtype']))):
-            return False
-    return True
-
-
-def calc_broadcast_shape_from_param(params):
-    """
-    Calculate the broadcast shape from params.
-
-    Args:
-        params (dict): parameters used to initialize distribution.
-
-    Returns:
-        tuple.
-    """
-    broadcast_shape = []
-    for value in params.values():
-        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
-            return params['distribution'].broadcast_shape
-        if isinstance(value, (str, type(params['dtype']))):
-            continue
-        if value is None:
-            return None
-        if isinstance(value, Parameter):
-            value_t = value.data
-        else:
-            value_t = cast_to_tensor(value, mstype.float32)
-        broadcast_shape = utils.get_broadcast_shape(
-            broadcast_shape, list(value_t.shape), params['name'])
-    return tuple(broadcast_shape)
-
-
 def check_greater_equal_zero(value, name):
     """
     Check if the given Tensor is greater zero.
@@ -371,6 +292,9 @@
     Raises:
         TypeError: if tensors in args are not the same dtype.
     """
+    int_type = mstype.int_type + mstype.uint_type
+    if hint_type in int_type:
+        hint_type = mstype.float32
     common_dtype = None
     for name, arg in args.items():
         if hasattr(arg, 'dtype'):
@@ -382,7 +306,6 @@
             common_dtype = cur_dtype
         elif cur_dtype != common_dtype:
             raise TypeError(f"{name} should have the same dtype as other arguments.")
-    int_type = mstype.int_type + mstype.uint_type
     if common_dtype in int_type or common_dtype == mstype.float64:
         return mstype.float32
     return hint_type if common_dtype is None else common_dtype
diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py
index 51ab3d7df3..120e9b6359 100644
--- a/mindspore/nn/probability/distribution/bernoulli.py
+++ b/mindspore/nn/probability/distribution/bernoulli.py
@@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type
+from ._utils.utils import check_prob, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, log_generic


@@ -116,18 +116,14 @@
         Constructor of Bernoulli.
         """
         param = dict(locals())
+        param['param_dict'] = {'probs': probs}
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Bernoulli, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
-        if probs is not None:
-            self._probs = cast_to_tensor(probs, self.parameter_type)
-            check_prob(self.probs)
-        else:
-            self._probs = probs
-        self.default_parameters = [self.probs]
-        self.parameter_names = ['probs1']
+        self._probs = self._add_parameter(probs, 'probs')
+        if self._probs is not None:
+            check_prob(self.probs)

         # ops needed for the class
         self.exp = exp_generic
@@ -135,14 +131,11 @@
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
-        self.dtypeop = P.DType()
         self.floor = P.Floor()
         self.fill = P.Fill()
         self.less = P.Less()
         self.shape = P.Shape()
         self.select = P.Select()
-        self.sq = P.Square()
-        self.sqrt = P.Sqrt()
         self.uniform = C.uniform

     def extend_repr(self):
@@ -173,9 +166,8 @@
             MODE(B) = 1 if probs1 > 0.5 else = 0
         """
         probs1 = self._check_param_type(probs1)
-        prob_type = self.dtypeop(probs1)
-        zeros = self.fill(prob_type, self.shape(probs1), 0.0)
-        ones = self.fill(prob_type, self.shape(probs1), 1.0)
+        zeros = self.fill(self.dtype, self.shape(probs1), 0.0)
+        ones = self.fill(self.dtype, self.shape(probs1), 1.0)
         comp = self.less(0.5, probs1)
         return self.select(comp, ones, zeros)

@@ -244,13 +236,13 @@
         value = self.cast(value, self.parameter_type)
         value = self.floor(value)
         probs1 = self._check_param_type(probs1)
-        prob_type = self.dtypeop(probs1)
-        value = value * self.fill(prob_type, self.shape(probs1), 1.0)
-        probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
+        broadcast_shape_tensor = value * probs1
+        value = self.broadcast(value, broadcast_shape_tensor)
+        probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor)
         comp_zero = self.less(value, 0.0)
         comp_one = self.less(value, 1.0)
-        zeros = self.fill(prob_type, self.shape(value), 0.0)
-        ones = self.fill(prob_type, self.shape(value), 1.0)
+        zeros = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 0.0)
+        ones = self.fill(self.parameter_type, self.shape(broadcast_shape_tensor), 1.0)
         less_than_zero = self.select(comp_zero, zeros, probs0)
         return self.select(comp_one, less_than_zero, ones)

diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py
index eb2719f15d..87a963ea6d 100644
--- a/mindspore/nn/probability/distribution/distribution.py
+++ b/mindspore/nn/probability/distribution/distribution.py
@@ -14,13 +14,14 @@
 # ============================================================================
 """basic"""
 from mindspore import context
+from mindspore.ops import operations as P
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore.common import get_seed
-from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device,\
-    raise_none_error
+from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device
 from ._utils.utils import CheckTuple, CheckTensor
+from ._utils.custom_ops import broadcast_to, exp_generic, log_generic


 class Distribution(Cell):
@@ -68,14 +69,16 @@
         self._seed = seed
         self._dtype = cast_type_for_device(dtype)
         self._parameters = {}
+        # parsing parameters
         for k in param.keys():
             if not(k == 'self' or k.startswith('_')):
                 self._parameters[k] = param[k]
+
+        # some attributes
-        self._broadcast_shape = calc_broadcast_shape_from_param(
-            self.parameters)
-        self._is_scalar_batch = check_scalar_from_param(self.parameters)
+        self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
+        self._broadcast_shape = self._calc_broadcast_shape()
+        self._is_scalar_batch = self._check_is_scalar_batch()

         # set the function to call according to the derived class's attributes
         self._set_prob()
@@ -91,6 +94,18 @@
         self.context_mode = context.get_context('mode')
         self.checktuple = CheckTuple()
         self.checktensor = CheckTensor()
+        self.broadcast = broadcast_to
+
+        # ops needed for the base class
+        self.cast_base = P.Cast()
+        self.dtype_base = P.DType()
+        self.exp_base = exp_generic
+        self.fill_base = P.Fill()
+        self.log_base = log_generic
+        self.sametypeshape_base = P.SameTypeShape()
+        self.sq_base = P.Square()
+        self.sqrt_base = P.Sqrt()
+        self.shape_base = P.Shape()

     @property
     def name(self):
@@ -116,6 +131,21 @@
     def broadcast_shape(self):
         return self._broadcast_shape

+    def _add_parameter(self, value, name):
+        """
+        Cast `value` to a tensor and add it to `self.default_parameters`.
+        Add `name` to `self.parameter_names`.
+        """
+        # initialize the attributes if they do not exist yet
+        if not hasattr(self, 'default_parameters'):
+            self.default_parameters = []
+            self.parameter_names = []
+        # cast value to a tensor if it is not None
+        value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
+        self.default_parameters += [value_t,]
+        self.parameter_names += [name,]
+        return value_t
+
     def _check_param_type(self, *args):
         """
         Check the availability and validity of default parameters and `dist_spec_args`.
         If the default parameters of the distribution
         are None, the parameters must be passed in through `args`.
         """
         broadcast_shape = None
+        broadcast_shape_tensor = None
         common_dtype = None
         out = []

@@ -139,17 +170,17 @@
             # broadcast if the number of args > 1
             if broadcast_shape is None:
-                broadcast_shape = self.shape(arg)
-                common_dtype = self.dtypeop(arg)
+                broadcast_shape = self.shape_base(arg)
+                common_dtype = self.dtype_base(arg)
+                broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
             else:
-                ones = self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
-                broadcast_shape = self.shape(arg + ones)
-
+                broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
+                broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
+                arg = self.broadcast(arg, broadcast_shape_tensor)
             # check if the arguments have the same dtype
-            arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
-            dtype_tensor = self.fill(common_dtype, broadcast_shape, 1.0)
-            self.sametypeshape(arg, dtype_tensor)
-            arg = self.cast(arg, self.parameter_type)
+            self.sametypeshape_base(arg, broadcast_shape_tensor)
+
+            arg = self.cast_base(arg, self.parameter_type)
             out.append(arg)

         if len(out) == 1:
@@ -158,7 +189,7 @@
         # broadcast all args to broadcast_shape
         result = ()
         for arg in out:
-            arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
+            arg = self.broadcast(arg, broadcast_shape_tensor)
             result = result + (arg,)
         return result

@@ -171,6 +202,38 @@
             return value
         return self.checktensor(value, name)

+    def _check_is_scalar_batch(self):
+        """
+        Check if the parameters used during initialization are scalars.
+        """
+        if hasattr(self, 'distribution'):
+            return self._distribution.is_scalar_batch
+        param_dict = self.parameters['param_dict']
+        for value in param_dict.values():
+            if value is None:
+                continue
+            if not isinstance(value, (int, float)):
+                return False
+        return True
+
+    def _calc_broadcast_shape(self):
+        """
+        Calculate the broadcast shape of the parameters used during initialization.
+        """
+        if hasattr(self, 'distribution'):
+            return self._distribution.broadcast_shape
+        param_dict = self.parameters['param_dict']
+        broadcast_shape_tensor = None
+        for value in param_dict.values():
+            if value is None:
+                return None
+            if broadcast_shape_tensor is None:
+                broadcast_shape_tensor = cast_to_tensor(value)
+            else:
+                value = cast_to_tensor(value)
+                broadcast_shape_tensor = (value + broadcast_shape_tensor)
+        return broadcast_shape_tensor.shape
+
     def _set_prob(self):
         """
         Set probability funtion based on the availability of `_prob` and `_log_likehood`.
         """
@@ -280,7 +343,7 @@
         .. math::
             probability(x) = \exp(log_likehood(x))
         """
-        return self.exp(self._log_prob(value, *args, **kwargs))
+        return self.exp_base(self._log_prob(value, *args, **kwargs))

     def prob(self, value, *args, **kwargs):
         """
@@ -304,7 +367,7 @@
         .. math::
             log_prob(x) = \log(prob(x))
         """
-        return self.log(self._prob(value, *args, **kwargs))
+        return self.log_base(self._prob(value, *args, **kwargs))

     def cdf(self, value, *args, **kwargs):
         """
@@ -328,7 +391,7 @@
         .. math::
             cdf(x) = \exp(log_cdf(x))
         """
-        return self.exp(self._log_cdf(value, *args, **kwargs))
+        return self.exp_base(self._log_cdf(value, *args, **kwargs))

     def _calc_cdf_from_survival(self, value, *args, **kwargs):
         r"""
@@ -346,7 +409,7 @@
         .. math::
             cdf(x) = 1 - (\exp(log_survival(x)))
         """
-        return 1.0 - self.exp(self._log_survival(value, *args, **kwargs))
+        return 1.0 - self.exp_base(self._log_survival(value, *args, **kwargs))

     def log_cdf(self, value, *args, **kwargs):
         """
@@ -370,7 +433,7 @@
         .. math::
             log_cdf(x) = \log(cdf(x))
         """
-        return self.log(self._call_cdf(value, *args, **kwargs))
+        return self.log_base(self._call_cdf(value, *args, **kwargs))

     def survival_function(self, value, *args, **kwargs):
         """
@@ -403,7 +466,7 @@
         .. math::
             survival(x) = \exp(survival_function(x))
         """
-        return self.exp(self._log_survival(value, *args, **kwargs))
+        return self.exp_base(self._log_survival(value, *args, **kwargs))

     def log_survival(self, value, *args, **kwargs):
         """
@@ -427,7 +490,7 @@
         .. math::
             log_survival(x) = \log(survival_function(x))
         """
-        return self.log(self._call_survival(value, *args, **kwargs))
+        return self.log_base(self._call_survival(value, *args, **kwargs))

     def kl_loss(self, dist, *args, **kwargs):
         """
@@ -507,7 +570,7 @@
         .. math::
             STD(x) = \sqrt(VAR(x))
         """
-        return self.sqrt(self._var(*args, **kwargs))
+        return self.sqrt_base(self._var(*args, **kwargs))

     def _calc_var_from_sd(self, *args, **kwargs):
         r"""
@@ -516,7 +579,7 @@
         .. math::
             VAR(x) = STD(x) ^ 2
         """
-        return self.sq(self._sd(*args, **kwargs))
+        return self.sq_base(self._sd(*args, **kwargs))

     def entropy(self, *args, **kwargs):
         """
diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py
index 8b2f3aa83e..3edc040f31 100644
--- a/mindspore/nn/probability/distribution/exponential.py
+++ b/mindspore/nn/probability/distribution/exponential.py
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type
+from ._utils.utils import check_greater_zero, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, log_generic


@@ -118,18 +118,14 @@
         Constructor of Exponential.
         """
         param = dict(locals())
+        param['param_dict'] = {'rate': rate}
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Exponential, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type({'rate': rate}, self.dtype)
-        if rate is not None:
-            self._rate = cast_to_tensor(rate, self.parameter_type)
-            check_greater_zero(self._rate, "rate")
-        else:
-            self._rate = rate
-        self.default_parameters = [self.rate]
-        self.parameter_names = ['rate']
+        self._rate = self._add_parameter(rate, 'rate')
+        if self.rate is not None:
+            check_greater_zero(self.rate, 'rate')

         self.minval = np.finfo(np.float).tiny

@@ -144,8 +140,6 @@
         self.less = P.Less()
         self.select = P.Select()
         self.shape = P.Shape()
-        self.sqrt = P.Sqrt()
-        self.sq = P.Square()
         self.uniform = C.uniform

     def extend_repr(self):
diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py
index 4087949c15..9a37d308c9 100644
--- a/mindspore/nn/probability/distribution/geometric.py
+++ b/mindspore/nn/probability/distribution/geometric.py
@@ -18,8 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
-    set_param_type
+from ._utils.utils import check_prob, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, log_generic


@@ -120,18 +119,14 @@
         Constructor of Geometric distribution.
         """
         param = dict(locals())
+        param['param_dict'] = {'probs': probs}
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Geometric, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
-        if probs is not None:
-            self._probs = cast_to_tensor(probs, self.parameter_type)
-            check_prob(self._probs)
-        else:
-            self._probs = probs
-        self.default_parameters = [self.probs]
-        self.parameter_names = ['probs1']
+        self._probs = self._add_parameter(probs, 'probs')
+        if self._probs is not None:
+            check_prob(self.probs)

         self.minval = np.finfo(np.float).tiny

@@ -150,7 +145,6 @@
         self.select = P.Select()
         self.shape = P.Shape()
         self.sq = P.Square()
-        self.sqrt = P.Sqrt()
         self.uniform = C.uniform

     def extend_repr(self):
@@ -181,7 +175,7 @@
             MODE(Geo) = 0
         """
         probs1 = self._check_param_type(probs1)
-        return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
+        return self.fill(self.dtype, self.shape(probs1), 0.)

     def _var(self, probs1=None):
         r"""
@@ -229,7 +223,7 @@
         value = self.floor(value)
         probs1 = self._check_param_type(probs1)
         pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
-        zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
+        zeros = self.fill(self.dtypeop(pmf), self.shape(pmf), 0.0)
         comp = self.less(value, zeros)
         return self.select(comp, zeros, pmf)

@@ -252,7 +246,7 @@
         probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         cdf = 1.0 - self.pow(probs0, value + 1.0)
-        zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
+        zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
         comp = self.less(value, zeros)
         return self.select(comp, zeros, cdf)

diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py
index 7d4da39b6a..0df0d2b8e4 100644
--- a/mindspore/nn/probability/distribution/normal.py
+++ b/mindspore/nn/probability/distribution/normal.py
@@ -18,8 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
-    set_param_type
+from ._utils.utils import check_greater_zero, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, expm1_generic, log_generic


@@ -125,23 +124,15 @@
         Constructor of Normal.
         """
         param = dict(locals())
+        param['param_dict'] = {'mean': mean, 'sd': sd}
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Normal, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type(
-            {'mean': mean, 'sd': sd}, self.dtype)
-        if mean is not None and sd is not None:
-            self._mean_value = cast_to_tensor(mean, self.parameter_type)
-            self._sd_value = cast_to_tensor(sd, self.parameter_type)
-            check_greater_zero(self._sd_value, "Standard deviation")
-        else:
-            self._mean_value = mean if mean is None else cast_to_tensor(
-                mean, self.parameter_type)
-            self._sd_value = sd if sd is None else cast_to_tensor(
-                sd, self.parameter_type)
-        self.default_parameters = [self._mean_value, self._sd_value]
-        self.parameter_names = ['mean', 'sd']
+        self._mean_value = self._add_parameter(mean, 'mean')
+        self._sd_value = self._add_parameter(sd, 'sd')
+        if self._sd_value is not None:
+            check_greater_zero(self._sd_value, "Standard deviation")

         # ops needed for the class
         self.exp = exp_generic
@@ -151,13 +142,9 @@
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
-        self.fill = P.Fill()
         self.shape = P.Shape()
         self.sq = P.Square()
         self.sqrt = P.Sqrt()
-        self.zeroslike = P.ZerosLike()
-        self.dtypeop = P.DType()
-        self.sametypeshape = P.SameTypeShape()

     def extend_repr(self):
         if self.is_scalar_batch:
diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py
index dd0bacc20e..c47ba7db0b 100644
--- a/mindspore/nn/probability/distribution/transformed_distribution.py
+++ b/mindspore/nn/probability/distribution/transformed_distribution.py
@@ -81,6 +81,8 @@
         self._bijector = bijector
         self._distribution = distribution
         self._is_linear_transformation = bijector.is_constant_jacobian
+        self.default_parameters = distribution.default_parameters
+        self.parameter_names = distribution.parameter_names

         self.exp = exp_generic
         self.log = log_generic
diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py
index 6668161cd5..21a1754bac 100644
--- a/mindspore/nn/probability/distribution/uniform.py
+++ b/mindspore/nn/probability/distribution/uniform.py
@@ -17,8 +17,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
-    set_param_type
+from ._utils.utils import check_greater, check_type, check_distribution_name
 from ._utils.custom_ops import exp_generic, log_generic


@@ -124,23 +123,16 @@
         Constructor of Uniform distribution.
         """
         param = dict(locals())
+        param['param_dict'] = {'low': low, 'high': high}
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Uniform, self).__init__(seed, dtype, name, param)
-        self.parameter_type = set_param_type(
-            {'low': low, 'high': high}, self.dtype)
-        if low is not None and high is not None:
-            self._low = cast_to_tensor(low, self.parameter_type)
-            self._high = cast_to_tensor(high, self.parameter_type)
-            check_greater(self.low, self.high, "low value", "high value")
-        else:
-            self._low = low if low is None else cast_to_tensor(
-                low, self.parameter_type)
-            self._high = high if high is None else cast_to_tensor(
-                high, self.parameter_type)
-        self.default_parameters = [self.low, self.high]
-        self.parameter_names = ['low', 'high']
+        self._low = self._add_parameter(low, 'low')
+        self._high = self._add_parameter(high, 'high')
+        if self.low is not None and self.high is not None:
+            check_greater(self.low, self.high, 'low', 'high')
+

         # ops needed for the class
         self.exp = exp_generic
@@ -156,12 +148,9 @@
         self.select = P.Select()
         self.shape = P.Shape()
         self.sq = P.Square()
-        self.sqrt = P.Sqrt()
         self.zeroslike = P.ZerosLike()
         self.uniform = C.uniform

-        self.sametypeshape = P.SameTypeShape()
-
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'low = {self.low}, high = {self.high}'
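
Note for reviewers: the new `_calc_broadcast_shape` derives the batch shape by literally adding the parameter tensors together, so the framework's own broadcasting rules do the work (and raise on incompatible shapes). Below is a minimal plain-NumPy sketch of the same computation; the function name and example values are illustrative only, not part of the patch.

import numpy as np

def calc_broadcast_shape_sketch(param_dict):
    """Plain-NumPy mirror of Distribution._calc_broadcast_shape."""
    acc = None
    for value in param_dict.values():
        if value is None:
            return None  # an unset parameter: batch shape unknown until call time
        value = np.asarray(value, np.float32)
        # summing the running tensor with each parameter applies (and
        # validates) the broadcasting rules without reimplementing them
        acc = value if acc is None else value + acc
    return acc.shape

# e.g. Normal(mean=[0., 1.], sd=[[1.], [2.]]) gets batch shape (2, 2)
print(calc_broadcast_shape_sketch({'mean': [0.0, 1.0], 'sd': [[1.0], [2.0]]}))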
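
Likewise, the new `broadcast_to` in custom_ops.py is a guarded `P.BroadcastTo`: when the shapes already match it returns the input untouched, so the common scalar-parameter path adds no extra op to the graph. A hedged usage sketch (it assumes a working MindSpore install running in PyNative mode; the shapes are made up for illustration):

import numpy as np
from mindspore import Tensor
from mindspore.nn.probability.distribution._utils import broadcast_to

probs = Tensor(np.array([0.3, 0.7], np.float32))   # shape (2,)
value = Tensor(np.ones((3, 2), np.float32))        # shape (3, 2)

# expand probs against value, as _check_param_type now does internally
print(broadcast_to(probs, value).shape)            # (3, 2)
# matching shapes: the same tensor comes back, no BroadcastTo op inserted
print(broadcast_to(value, value).shape)            # (3, 2)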