diff --git a/mindspore/nn/probability/bijector/bijector.py b/mindspore/nn/probability/bijector/bijector.py index 1e36ae0906..e59d8d3154 100644 --- a/mindspore/nn/probability/bijector/bijector.py +++ b/mindspore/nn/probability/bijector/bijector.py @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================ """Bijector""" +from mindspore import context from mindspore.nn.cell import Cell from mindspore._checkparam import Validator as validator +from ..distribution._utils.utils import CheckTensor from ..distribution import Distribution from ..distribution import TransformedDistribution @@ -40,7 +42,7 @@ class Bijector(Cell): Constructor of bijector class. """ super(Bijector, self).__init__() - validator.check_value_type('name', name, [str], 'Bijector') + validator.check_value_type('name', name, [str], type(self).__name__) validator.check_value_type('is_constant_jacobian', is_constant_jacobian, [bool], name) validator.check_value_type('is_injective', is_injective, [bool], name) self._name = name @@ -53,6 +55,9 @@ class Bijector(Cell): self._is_constant_jacobian = is_constant_jacobian self._is_injective = is_injective + self.context_mode = context.get_context('mode') + self.checktensor = CheckTensor() + @property def name(self): return self._name @@ -73,6 +78,15 @@ class Bijector(Cell): def is_injective(self): return self._is_injective + def _check_value(self, value, name): + """ + Check availability fo value as a Tensor. + """ + if self.context_mode == 0: + self.checktensor(value, name) + return value + return self.checktensor(value, name) + def forward(self, *args, **kwargs): """ Forward transformation: transform the input value to another distribution. diff --git a/mindspore/nn/probability/bijector/power_transform.py b/mindspore/nn/probability/bijector/power_transform.py index 279de16dc6..6550fae036 100644 --- a/mindspore/nn/probability/bijector/power_transform.py +++ b/mindspore/nn/probability/bijector/power_transform.py @@ -16,7 +16,6 @@ from mindspore.ops import operations as P from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel -from ..distribution._utils.utils import CheckTensor from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic from .bijector import Bijector @@ -66,8 +65,6 @@ class PowerTransform(Bijector): self.log = log_generic self.log1p = log1p_generic - self.checktensor = CheckTensor() - @property def power(self): return self._power @@ -80,13 +77,13 @@ class PowerTransform(Bijector): return shape def _forward(self, x): - self.checktensor(x, 'value') + x = self._check_value(x, 'value') if self.power == 0: return self.exp(x) return self.exp(self.log1p(x * self.power) / self.power) def _inverse(self, y): - self.checktensor(y, 'value') + y = self._check_value(y, 'value') if self.power == 0: return self.log(y) return self.expm1(self.log(y) * self.power) / self.power @@ -103,7 +100,7 @@ class PowerTransform(Bijector): f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1} \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1) """ - self.checktensor(x, 'value') + x = self._check_value(x, 'value') if self.power == 0: return x return (1. / self.power - 1) * self.log1p(x * self.power) @@ -120,5 +117,5 @@ class PowerTransform(Bijector): f'(x) = \frac{e^c\log(y)}{y} \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y) """ - self.checktensor(y, 'value') + y = self._check_value(y, 'value') return (self.power - 1) * self.log(y) diff --git a/mindspore/nn/probability/bijector/scalar_affine.py b/mindspore/nn/probability/bijector/scalar_affine.py index b6d079fef1..b75298be20 100644 --- a/mindspore/nn/probability/bijector/scalar_affine.py +++ b/mindspore/nn/probability/bijector/scalar_affine.py @@ -15,7 +15,7 @@ """Scalar Affine Bijector""" from mindspore.ops import operations as P from mindspore._checkparam import Validator as validator -from ..distribution._utils.utils import cast_to_tensor, CheckTensor +from ..distribution._utils.utils import cast_to_tensor from ..distribution._utils.custom_ops import log_generic from .bijector import Bijector @@ -57,8 +57,8 @@ class ScalarAffine(Bijector): Constructor of scalar affine bijector. """ param = dict(locals()) - validator.check_value_type('scale', scale, [int, float], name) - validator.check_value_type('shift', shift, [int, float], name) + validator.check_value_type('scale', scale, [int, float], type(self).__name__) + validator.check_value_type('shift', shift, [int, float], type(self).__name__) self._scale = cast_to_tensor(scale) self._shift = cast_to_tensor(shift) super(ScalarAffine, self).__init__( @@ -71,8 +71,6 @@ class ScalarAffine(Bijector): self.abs = P.Abs() self.log = log_generic - self.checktensor = CheckTensor() - @property def scale(self): return self._scale @@ -93,7 +91,7 @@ class ScalarAffine(Bijector): .. math:: f(x) = a * x + b """ - self.checktensor(x, 'value') + x = self._check_value(x, 'value') return self.scale * x + self.shift def _inverse(self, y): @@ -101,7 +99,7 @@ class ScalarAffine(Bijector): .. math:: f(y) = \frac{y - b}{a} """ - self.checktensor(y, 'value') + y = self._check_value(y, 'value') return (y - self.shift) / self.scale def _forward_log_jacobian(self, x): @@ -111,7 +109,7 @@ class ScalarAffine(Bijector): f'(x) = a \log(f'(x)) = \log(a) """ - self.checktensor(x, 'value') + x = self._check_value(x, 'value') return self.log(self.abs(self.scale)) def _inverse_log_jacobian(self, y): @@ -121,5 +119,5 @@ class ScalarAffine(Bijector): f'(x) = \frac{1.0}{a} \log(f'(x)) = - \log(a) """ - self.checktensor(y, 'value') + y = self._check_value(y, 'value') return -1. * self.log(self.abs(self.scale)) diff --git a/mindspore/nn/probability/bijector/softplus.py b/mindspore/nn/probability/bijector/softplus.py index 1d710ae480..6ed8695820 100644 --- a/mindspore/nn/probability/bijector/softplus.py +++ b/mindspore/nn/probability/bijector/softplus.py @@ -18,7 +18,7 @@ from mindspore.ops import operations as P from mindspore.common import dtype as mstype from mindspore.nn.layer.activation import LogSigmoid from mindspore._checkparam import Validator as validator -from ..distribution._utils.utils import cast_to_tensor, CheckTensor +from ..distribution._utils.utils import cast_to_tensor from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic from .bijector import Bijector @@ -57,7 +57,7 @@ class Softplus(Bijector): sharpness=1.0, name='Softplus'): param = dict(locals()) - validator.check_value_type('sharpness', sharpness, [int, float], name) + validator.check_value_type('sharpness', sharpness, [int, float], type(self).__name__) super(Softplus, self).__init__(name=name, param=param) self._sharpness = cast_to_tensor(sharpness) @@ -76,7 +76,6 @@ class Softplus(Bijector): self.softplus = self._softplus self.inverse_softplus = self._inverse_softplus - self.checktensor = CheckTensor() self.threshold = np.log(np.finfo(np.float32).eps) + 1 self.tiny = np.exp(self.threshold) @@ -119,7 +118,7 @@ class Softplus(Bijector): return shape def _forward(self, x): - self.checktensor(x, 'value') + x = self._check_value(x, 'value') scaled_value = self.sharpness * x return self.softplus(scaled_value) / self.sharpness @@ -129,7 +128,7 @@ class Softplus(Bijector): f(x) = \frac{\log(1 + e^{kx}))}{k} f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k} """ - self.checktensor(y, 'value') + y = self._check_value(y, 'value') scaled_value = self.sharpness * y return self.inverse_softplus(scaled_value) / self.sharpness @@ -140,7 +139,7 @@ class Softplus(Bijector): f'(x) = \frac{e^{kx}}{ 1 + e^{kx}} \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx) """ - self.checktensor(x, 'value') + x = self._check_value(x, 'value') scaled_value = self.sharpness * x return self.log_sigmoid(scaled_value) @@ -151,6 +150,6 @@ class Softplus(Bijector): f'(y) = \frac{e^{ky}}{e^{ky} - 1} \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky) """ - self.checktensor(y, 'value') + y = self._check_value(y, 'value') scaled_value = self.sharpness * y return scaled_value - self.inverse_softplus(scaled_value) diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index d49088ad48..90f09fe41c 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -342,7 +342,7 @@ class CheckTuple(PrimitiveWithInfer): # Pynative mode if isinstance(x, tuple): return x - raise TypeError(f"For {name['value']}, Input type should b a tuple.") + raise TypeError(f"For {name}, input type should be a tuple.") class CheckTensor(PrimitiveWithInfer): @@ -365,4 +365,6 @@ class CheckTensor(PrimitiveWithInfer): return out def __call__(self, x, name): - return + if isinstance(x, Tensor): + return x + raise TypeError(f"For {name}, input type should be a Tensor.") diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index 018f7dd2b0..cd03256d4d 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -99,7 +99,7 @@ class Bernoulli(Distribution): """ param = dict(locals()) valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type - check_type(dtype, valid_dtype, "Bernoulli") + check_type(dtype, valid_dtype, type(self).__name__) super(Bernoulli, self).__init__(seed, dtype, name, param) self.parameter_type = mstype.float32 if probs is not None: @@ -144,7 +144,10 @@ class Bernoulli(Distribution): Check availablity of distribution specific args probs1. """ if probs1 is not None: - self.checktensor(probs1, 'probs1') + if self.context_mode == 0: + self.checktensor(probs1, 'probs1') + else: + probs1 = self.checktensor(probs1, 'probs1') return self.cast(probs1, self.parameter_type) return self.probs if self.probs is not None else raise_none_error('probs1') @@ -210,7 +213,7 @@ class Bernoulli(Distribution): pmf(k) = probs1 if k = 1; pmf(k) = probs0 if k = 0; """ - self.checktensor(value, 'value') + value = self._check_value(value, 'value') value = self.cast(value, mstype.float32) probs1 = self._check_param(probs1) probs0 = 1.0 - probs1 @@ -229,7 +232,7 @@ class Bernoulli(Distribution): cdf(k) = probs0 if 0 <= k <1; cdf(k) = 1 if k >=1; """ - self.checktensor(value, 'value') + value = self._check_value(value, 'value') value = self.cast(value, mstype.float32) value = self.floor(value) probs1 = self._check_param(probs1) @@ -257,7 +260,7 @@ class Bernoulli(Distribution): probs0_a * \log(\frac{probs0_a}{probs0_b}) """ check_distribution_name(dist, 'Bernoulli') - self.checktensor(probs1_b, 'probs1_b') + probs1_b = self._check_value(probs1_b, 'probs1_b') probs1_b = self.cast(probs1_b, self.parameter_type) probs1_a = self._check_param(probs1) probs0_a = 1.0 - probs1_a diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py index 203c97fcc6..9514198dfd 100644 --- a/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/nn/probability/distribution/distribution.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """basic""" +from mindspore import context from mindspore.nn.cell import Cell from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel @@ -54,7 +55,7 @@ class Distribution(Cell): Constructor of distribution class. """ super(Distribution, self).__init__() - validator.check_value_type('name', name, [str], 'distribution_name') + validator.check_value_type('name', name, [str], type(self).__name__) validator.check_integer('seed', seed, 0, Rel.GE, name) self._name = name @@ -81,6 +82,7 @@ class Distribution(Cell): self._set_log_survival() self._set_cross_entropy() + self.context_mode = context.get_context('mode') self.checktuple = CheckTuple() self.checktensor = CheckTensor() @@ -108,6 +110,15 @@ class Distribution(Cell): def broadcast_shape(self): return self._broadcast_shape + def _check_value(self, value, name): + """ + Check availability fo value as a Tensor. + """ + if self.context_mode == 0: + self.checktensor(value, name) + return value + return self.checktensor(value, name) + def _set_prob(self): """ Set probability funtion based on the availability of _prob and _log_likehood. diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py index 2d0488b105..43886ec096 100644 --- a/mindspore/nn/probability/distribution/exponential.py +++ b/mindspore/nn/probability/distribution/exponential.py @@ -100,7 +100,7 @@ class Exponential(Distribution): """ param = dict(locals()) valid_dtype = mstype.float_type - check_type(dtype, valid_dtype, "Exponential") + check_type(dtype, valid_dtype, type(self).__name__) super(Exponential, self).__init__(seed, dtype, name, param) self.parameter_type = dtype if rate is not None: @@ -146,7 +146,10 @@ class Exponential(Distribution): Check availablity of distribution specific args rate. """ if rate is not None: - self.checktensor(rate, 'rate') + if self.context_mode == 0: + self.checktensor(rate, 'rate') + else: + rate = self.checktensor(rate, 'rate') return self.cast(rate, self.parameter_type) return self.rate if self.rate is not None else raise_none_error('rate') @@ -210,7 +213,7 @@ class Exponential(Distribution): .. math:: pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 """ - self.checktensor(value, "value") + value = self._check_value(value, "value") value = self.cast(value, self.dtype) rate = self._check_param(rate) prob = self.exp(self.log(rate) - rate * value) @@ -232,7 +235,7 @@ class Exponential(Distribution): .. math:: cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0 """ - self.checktensor(value, 'value') + value = self._check_value(value, 'value') value = self.cast(value, self.dtype) rate = self._check_param(rate) cdf = 1.0 - self.exp(-1. * rate * value) @@ -251,7 +254,7 @@ class Exponential(Distribution): rate_a (Tensor): rate of distribution a. Default: self.rate. """ check_distribution_name(dist, 'Exponential') - self.checktensor(rate_b, 'rate_b') + rate_b = self._check_value(rate_b, 'rate_b') rate_b = self.cast(rate_b, self.parameter_type) rate_a = self._check_param(rate) return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py index a15b5f37e6..230f7a9174 100644 --- a/mindspore/nn/probability/distribution/geometric.py +++ b/mindspore/nn/probability/distribution/geometric.py @@ -102,7 +102,7 @@ class Geometric(Distribution): """ param = dict(locals()) valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type - check_type(dtype, valid_dtype, "Geometric") + check_type(dtype, valid_dtype, type(self).__name__) super(Geometric, self).__init__(seed, dtype, name, param) self.parameter_type = mstype.float32 if probs is not None: @@ -150,7 +150,10 @@ class Geometric(Distribution): Check availablity of distribution specific args probs1. """ if probs1 is not None: - self.checktensor(probs1, 'probs1') + if self.context_mode == 0: + self.checktensor(probs1, 'probs1') + else: + probs1 = self.checktensor(probs1, 'probs1') return self.cast(probs1, self.parameter_type) return self.probs if self.probs is not None else raise_none_error('probs1') @@ -211,7 +214,7 @@ class Geometric(Distribution): pmf(k) = probs0 ^k * probs1 if k >= 0; pmf(k) = 0 if k < 0. """ - self.checktensor(value, 'value') + value = self._check_value(value, 'value') value = self.cast(value, mstype.float32) value = self.floor(value) probs1 = self._check_param(probs1) @@ -233,7 +236,7 @@ class Geometric(Distribution): cdf(k) = 0 if k < 0. """ - self.checktensor(value, 'value') + value = self._check_value(value, 'value') value = self.cast(value, mstype.float32) value = self.floor(value) probs1 = self._check_param(probs1) @@ -256,7 +259,7 @@ class Geometric(Distribution): KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b}) """ check_distribution_name(dist, 'Geometric') - self.checktensor(probs1_b, 'probs1_b') + probs1_b = self._check_value(probs1_b, 'probs1_b') probs1_b = self.cast(probs1_b, self.parameter_type) probs1_a = self._check_param(probs1) probs0_a = 1.0 - probs1_a diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py index 7005876229..1ad5b4dff7 100644 --- a/mindspore/nn/probability/distribution/normal.py +++ b/mindspore/nn/probability/distribution/normal.py @@ -18,7 +18,7 @@ from mindspore.ops import operations as P from mindspore.ops import composite as C from mindspore.common import dtype as mstype from .distribution import Distribution -from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\ +from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ raise_none_error from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic @@ -102,12 +102,12 @@ class Normal(Distribution): """ param = dict(locals()) valid_dtype = mstype.float_type - check_type(dtype, valid_dtype, "Normal") + check_type(dtype, valid_dtype, type(self).__name__) super(Normal, self).__init__(seed, dtype, name, param) self.parameter_type = dtype if mean is not None and sd is not None: - self._mean_value = convert_to_batch(mean, self.broadcast_shape, self.parameter_type) - self._sd_value = convert_to_batch(sd, self.broadcast_shape, self.parameter_type) + self._mean_value = cast_to_tensor(mean, self.parameter_type) + self._sd_value = cast_to_tensor(sd, self.parameter_type) check_greater_zero(self._sd_value, "Standard deviation") else: self._mean_value = mean @@ -139,12 +139,18 @@ class Normal(Distribution): Check availablity of distribution specific args mean and sd. """ if mean is not None: - self.checktensor(mean, 'mean') + if self.context_mode == 0: + self.checktensor(mean, 'mean') + else: + mean = self.checktensor(mean, 'mean') mean = self.cast(mean, self.parameter_type) else: mean = self._mean_value if self._mean_value is not None else raise_none_error('mean') if sd is not None: - self.checktensor(sd, 'sd') + if self.context_mode == 0: + self.checktensor(sd, 'sd') + else: + sd = self.checktensor(sd, 'sd') sd = self.cast(sd, self.parameter_type) else: sd = self._sd_value if self._sd_value is not None else raise_none_error('sd') @@ -210,7 +216,7 @@ class Normal(Distribution): .. math:: L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) """ - self.checktensor(value, 'value') + value = self._check_value(value, 'value') value = self.cast(value, self.dtype) mean, sd = self._check_param(mean, sd) unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) @@ -229,7 +235,7 @@ class Normal(Distribution): .. math:: cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2)))) """ - self.checktensor(value, 'value') + value = self._check_value(value, 'value') value = self.cast(value, self.dtype) mean, sd = self._check_param(mean, sd) sqrt2 = self.sqrt(self.const(2.0)) @@ -252,8 +258,8 @@ class Normal(Distribution): 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) """ check_distribution_name(dist, 'Normal') - self.checktensor(mean_b, 'mean_b') - self.checktensor(sd_b, 'sd_b') + mean_b = self._check_value(mean_b, 'mean_b') + sd_b = self._check_value(sd_b, 'sd_b') mean_b = self.cast(mean_b, self.parameter_type) sd_b = self.cast(sd_b, self.parameter_type) mean_a, sd_a = self._check_param(mean, sd) diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py index 37c14a6e53..e641bd1c0f 100644 --- a/mindspore/nn/probability/distribution/transformed_distribution.py +++ b/mindspore/nn/probability/distribution/transformed_distribution.py @@ -46,10 +46,10 @@ class TransformedDistribution(Distribution): Constructor of transformed_distribution class. """ param = dict(locals()) - validator.check_value_type('bijector', bijector, [nn.probability.bijector.Bijector], name) - validator.check_value_type('distribution', distribution, [Distribution], name) + validator.check_value_type('bijector', bijector, [nn.probability.bijector.Bijector], type(self).__name__) + validator.check_value_type('distribution', distribution, [Distribution], type(self).__name__) valid_dtype = mstype.number_type - check_type(dtype, valid_dtype, "transformed_distribution") + check_type(dtype, valid_dtype, type(self).__name__) super(TransformedDistribution, self).__init__(seed, dtype, name, param) self._bijector = bijector diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py index 7f623ddb74..d2882d70fe 100644 --- a/mindspore/nn/probability/distribution/uniform.py +++ b/mindspore/nn/probability/distribution/uniform.py @@ -17,7 +17,7 @@ from mindspore.ops import operations as P from mindspore.ops import composite as C from mindspore.common import dtype as mstype from .distribution import Distribution -from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\ +from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\ raise_none_error from ._utils.custom_ops import exp_generic, log_generic @@ -101,12 +101,12 @@ class Uniform(Distribution): """ param = dict(locals()) valid_dtype = mstype.float_type - check_type(dtype, valid_dtype, "Uniform") + check_type(dtype, valid_dtype, type(self).__name__) super(Uniform, self).__init__(seed, dtype, name, param) self.parameter_type = dtype if low is not None and high is not None: - self._low = convert_to_batch(low, self.broadcast_shape, dtype) - self._high = convert_to_batch(high, self.broadcast_shape, dtype) + self._low = cast_to_tensor(low, dtype) + self._high = cast_to_tensor(high, dtype) check_greater(self.low, self.high, "low value", "high value") else: self._low = low @@ -142,12 +142,18 @@ class Uniform(Distribution): Check availablity of distribution specific args low and high. """ if low is not None: - self.checktensor(low, 'low') + if self.context_mode == 0: + self.checktensor(low, 'low') + else: + low = self.checktensor(low, 'low') low = self.cast(low, self.parameter_type) else: low = self.low if self.low is not None else raise_none_error('low') if high is not None: - self.checktensor(high, 'high') + if self.context_mode == 0: + self.checktensor(high, 'high') + else: + high = self.checktensor(high, 'high') high = self.cast(high, self.parameter_type) else: high = self.high if self.high is not None else raise_none_error('high') @@ -231,7 +237,7 @@ class Uniform(Distribution): pdf(x) = \frac{1.0}{high -low} if low <= x <= high; pdf(x) = 0 if x > high; """ - self.checktensor(value, 'value') + value = self._check_value(value, 'value') value = self.cast(value, self.dtype) low, high = self._check_param(low, high) neg_ones = self.fill(self.dtype, self.shape(value), -1.0) @@ -255,9 +261,9 @@ class Uniform(Distribution): high_a (Tensor): upper bound of distribution a. Default: self.high. """ check_distribution_name(dist, 'Uniform') - self.checktensor(low_b, 'low_b') + low_b = self._check_value(low_b, 'low_b') low_b = self.cast(low_b, self.parameter_type) - self.checktensor(high_b, 'high_b') + high_b = self._check_value(high_b, 'high_b') high_b = self.cast(high_b, self.parameter_type) low_a, high_a = self._check_param(low, high) kl = self.log(high_b - low_b) - self.log(high_a - low_a) @@ -278,7 +284,7 @@ class Uniform(Distribution): cdf(x) = \frac{x - low}{high -low} if low <= x <= high; cdf(x) = 1 if x > high; """ - self.checktensor(value, 'value') + value = self._check_value(value, 'value') value = self.cast(value, self.dtype) low, high = self._check_param(low, high) prob = (value - low) / (high - low)