Merge pull request !5193 from XunDeng/pp_issue_branchtags/v1.0.0
| @@ -13,8 +13,10 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Bijector""" | |||
| from mindspore import context | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore._checkparam import Validator as validator | |||
| from ..distribution._utils.utils import CheckTensor | |||
| from ..distribution import Distribution | |||
| from ..distribution import TransformedDistribution | |||
| @@ -40,7 +42,7 @@ class Bijector(Cell): | |||
| Constructor of bijector class. | |||
| """ | |||
| super(Bijector, self).__init__() | |||
| validator.check_value_type('name', name, [str], 'Bijector') | |||
| validator.check_value_type('name', name, [str], type(self).__name__) | |||
| validator.check_value_type('is_constant_jacobian', is_constant_jacobian, [bool], name) | |||
| validator.check_value_type('is_injective', is_injective, [bool], name) | |||
| self._name = name | |||
| @@ -53,6 +55,9 @@ class Bijector(Cell): | |||
| self._is_constant_jacobian = is_constant_jacobian | |||
| self._is_injective = is_injective | |||
| self.context_mode = context.get_context('mode') | |||
| self.checktensor = CheckTensor() | |||
| @property | |||
| def name(self): | |||
| return self._name | |||
| @@ -73,6 +78,15 @@ class Bijector(Cell): | |||
| def is_injective(self): | |||
| return self._is_injective | |||
| def _check_value(self, value, name): | |||
| """ | |||
| Check availability fo value as a Tensor. | |||
| """ | |||
| if self.context_mode == 0: | |||
| self.checktensor(value, name) | |||
| return value | |||
| return self.checktensor(value, name) | |||
| def forward(self, *args, **kwargs): | |||
| """ | |||
| Forward transformation: transform the input value to another distribution. | |||
| @@ -16,7 +16,6 @@ | |||
| from mindspore.ops import operations as P | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._checkparam import Rel | |||
| from ..distribution._utils.utils import CheckTensor | |||
| from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic | |||
| from .bijector import Bijector | |||
| @@ -66,8 +65,6 @@ class PowerTransform(Bijector): | |||
| self.log = log_generic | |||
| self.log1p = log1p_generic | |||
| self.checktensor = CheckTensor() | |||
| @property | |||
| def power(self): | |||
| return self._power | |||
| @@ -80,13 +77,13 @@ class PowerTransform(Bijector): | |||
| return shape | |||
| def _forward(self, x): | |||
| self.checktensor(x, 'value') | |||
| x = self._check_value(x, 'value') | |||
| if self.power == 0: | |||
| return self.exp(x) | |||
| return self.exp(self.log1p(x * self.power) / self.power) | |||
| def _inverse(self, y): | |||
| self.checktensor(y, 'value') | |||
| y = self._check_value(y, 'value') | |||
| if self.power == 0: | |||
| return self.log(y) | |||
| return self.expm1(self.log(y) * self.power) / self.power | |||
| @@ -103,7 +100,7 @@ class PowerTransform(Bijector): | |||
| f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1} | |||
| \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1) | |||
| """ | |||
| self.checktensor(x, 'value') | |||
| x = self._check_value(x, 'value') | |||
| if self.power == 0: | |||
| return x | |||
| return (1. / self.power - 1) * self.log1p(x * self.power) | |||
| @@ -120,5 +117,5 @@ class PowerTransform(Bijector): | |||
| f'(x) = \frac{e^c\log(y)}{y} | |||
| \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y) | |||
| """ | |||
| self.checktensor(y, 'value') | |||
| y = self._check_value(y, 'value') | |||
| return (self.power - 1) * self.log(y) | |||
| @@ -15,7 +15,7 @@ | |||
| """Scalar Affine Bijector""" | |||
| from mindspore.ops import operations as P | |||
| from mindspore._checkparam import Validator as validator | |||
| from ..distribution._utils.utils import cast_to_tensor, CheckTensor | |||
| from ..distribution._utils.utils import cast_to_tensor | |||
| from ..distribution._utils.custom_ops import log_generic | |||
| from .bijector import Bijector | |||
| @@ -57,8 +57,8 @@ class ScalarAffine(Bijector): | |||
| Constructor of scalar affine bijector. | |||
| """ | |||
| param = dict(locals()) | |||
| validator.check_value_type('scale', scale, [int, float], name) | |||
| validator.check_value_type('shift', shift, [int, float], name) | |||
| validator.check_value_type('scale', scale, [int, float], type(self).__name__) | |||
| validator.check_value_type('shift', shift, [int, float], type(self).__name__) | |||
| self._scale = cast_to_tensor(scale) | |||
| self._shift = cast_to_tensor(shift) | |||
| super(ScalarAffine, self).__init__( | |||
| @@ -71,8 +71,6 @@ class ScalarAffine(Bijector): | |||
| self.abs = P.Abs() | |||
| self.log = log_generic | |||
| self.checktensor = CheckTensor() | |||
| @property | |||
| def scale(self): | |||
| return self._scale | |||
| @@ -93,7 +91,7 @@ class ScalarAffine(Bijector): | |||
| .. math:: | |||
| f(x) = a * x + b | |||
| """ | |||
| self.checktensor(x, 'value') | |||
| x = self._check_value(x, 'value') | |||
| return self.scale * x + self.shift | |||
| def _inverse(self, y): | |||
| @@ -101,7 +99,7 @@ class ScalarAffine(Bijector): | |||
| .. math:: | |||
| f(y) = \frac{y - b}{a} | |||
| """ | |||
| self.checktensor(y, 'value') | |||
| y = self._check_value(y, 'value') | |||
| return (y - self.shift) / self.scale | |||
| def _forward_log_jacobian(self, x): | |||
| @@ -111,7 +109,7 @@ class ScalarAffine(Bijector): | |||
| f'(x) = a | |||
| \log(f'(x)) = \log(a) | |||
| """ | |||
| self.checktensor(x, 'value') | |||
| x = self._check_value(x, 'value') | |||
| return self.log(self.abs(self.scale)) | |||
| def _inverse_log_jacobian(self, y): | |||
| @@ -121,5 +119,5 @@ class ScalarAffine(Bijector): | |||
| f'(x) = \frac{1.0}{a} | |||
| \log(f'(x)) = - \log(a) | |||
| """ | |||
| self.checktensor(y, 'value') | |||
| y = self._check_value(y, 'value') | |||
| return -1. * self.log(self.abs(self.scale)) | |||
| @@ -18,7 +18,7 @@ from mindspore.ops import operations as P | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.nn.layer.activation import LogSigmoid | |||
| from mindspore._checkparam import Validator as validator | |||
| from ..distribution._utils.utils import cast_to_tensor, CheckTensor | |||
| from ..distribution._utils.utils import cast_to_tensor | |||
| from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic | |||
| from .bijector import Bijector | |||
| @@ -57,7 +57,7 @@ class Softplus(Bijector): | |||
| sharpness=1.0, | |||
| name='Softplus'): | |||
| param = dict(locals()) | |||
| validator.check_value_type('sharpness', sharpness, [int, float], name) | |||
| validator.check_value_type('sharpness', sharpness, [int, float], type(self).__name__) | |||
| super(Softplus, self).__init__(name=name, param=param) | |||
| self._sharpness = cast_to_tensor(sharpness) | |||
| @@ -76,7 +76,6 @@ class Softplus(Bijector): | |||
| self.softplus = self._softplus | |||
| self.inverse_softplus = self._inverse_softplus | |||
| self.checktensor = CheckTensor() | |||
| self.threshold = np.log(np.finfo(np.float32).eps) + 1 | |||
| self.tiny = np.exp(self.threshold) | |||
| @@ -119,7 +118,7 @@ class Softplus(Bijector): | |||
| return shape | |||
| def _forward(self, x): | |||
| self.checktensor(x, 'value') | |||
| x = self._check_value(x, 'value') | |||
| scaled_value = self.sharpness * x | |||
| return self.softplus(scaled_value) / self.sharpness | |||
| @@ -129,7 +128,7 @@ class Softplus(Bijector): | |||
| f(x) = \frac{\log(1 + e^{kx}))}{k} | |||
| f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k} | |||
| """ | |||
| self.checktensor(y, 'value') | |||
| y = self._check_value(y, 'value') | |||
| scaled_value = self.sharpness * y | |||
| return self.inverse_softplus(scaled_value) / self.sharpness | |||
| @@ -140,7 +139,7 @@ class Softplus(Bijector): | |||
| f'(x) = \frac{e^{kx}}{ 1 + e^{kx}} | |||
| \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx) | |||
| """ | |||
| self.checktensor(x, 'value') | |||
| x = self._check_value(x, 'value') | |||
| scaled_value = self.sharpness * x | |||
| return self.log_sigmoid(scaled_value) | |||
| @@ -151,6 +150,6 @@ class Softplus(Bijector): | |||
| f'(y) = \frac{e^{ky}}{e^{ky} - 1} | |||
| \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky) | |||
| """ | |||
| self.checktensor(y, 'value') | |||
| y = self._check_value(y, 'value') | |||
| scaled_value = self.sharpness * y | |||
| return scaled_value - self.inverse_softplus(scaled_value) | |||
| @@ -342,7 +342,7 @@ class CheckTuple(PrimitiveWithInfer): | |||
| # Pynative mode | |||
| if isinstance(x, tuple): | |||
| return x | |||
| raise TypeError(f"For {name['value']}, Input type should b a tuple.") | |||
| raise TypeError(f"For {name}, input type should be a tuple.") | |||
| class CheckTensor(PrimitiveWithInfer): | |||
| @@ -365,4 +365,6 @@ class CheckTensor(PrimitiveWithInfer): | |||
| return out | |||
| def __call__(self, x, name): | |||
| return | |||
| if isinstance(x, Tensor): | |||
| return x | |||
| raise TypeError(f"For {name}, input type should be a Tensor.") | |||
| @@ -99,7 +99,7 @@ class Bernoulli(Distribution): | |||
| """ | |||
| param = dict(locals()) | |||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | |||
| check_type(dtype, valid_dtype, "Bernoulli") | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Bernoulli, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = mstype.float32 | |||
| if probs is not None: | |||
| @@ -144,7 +144,10 @@ class Bernoulli(Distribution): | |||
| Check availablity of distribution specific args probs1. | |||
| """ | |||
| if probs1 is not None: | |||
| self.checktensor(probs1, 'probs1') | |||
| if self.context_mode == 0: | |||
| self.checktensor(probs1, 'probs1') | |||
| else: | |||
| probs1 = self.checktensor(probs1, 'probs1') | |||
| return self.cast(probs1, self.parameter_type) | |||
| return self.probs if self.probs is not None else raise_none_error('probs1') | |||
| @@ -210,7 +213,7 @@ class Bernoulli(Distribution): | |||
| pmf(k) = probs1 if k = 1; | |||
| pmf(k) = probs0 if k = 0; | |||
| """ | |||
| self.checktensor(value, 'value') | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, mstype.float32) | |||
| probs1 = self._check_param(probs1) | |||
| probs0 = 1.0 - probs1 | |||
| @@ -229,7 +232,7 @@ class Bernoulli(Distribution): | |||
| cdf(k) = probs0 if 0 <= k <1; | |||
| cdf(k) = 1 if k >=1; | |||
| """ | |||
| self.checktensor(value, 'value') | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, mstype.float32) | |||
| value = self.floor(value) | |||
| probs1 = self._check_param(probs1) | |||
| @@ -257,7 +260,7 @@ class Bernoulli(Distribution): | |||
| probs0_a * \log(\frac{probs0_a}{probs0_b}) | |||
| """ | |||
| check_distribution_name(dist, 'Bernoulli') | |||
| self.checktensor(probs1_b, 'probs1_b') | |||
| probs1_b = self._check_value(probs1_b, 'probs1_b') | |||
| probs1_b = self.cast(probs1_b, self.parameter_type) | |||
| probs1_a = self._check_param(probs1) | |||
| probs0_a = 1.0 - probs1_a | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """basic""" | |||
| from mindspore import context | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._checkparam import Rel | |||
| @@ -54,7 +55,7 @@ class Distribution(Cell): | |||
| Constructor of distribution class. | |||
| """ | |||
| super(Distribution, self).__init__() | |||
| validator.check_value_type('name', name, [str], 'distribution_name') | |||
| validator.check_value_type('name', name, [str], type(self).__name__) | |||
| validator.check_integer('seed', seed, 0, Rel.GE, name) | |||
| self._name = name | |||
| @@ -81,6 +82,7 @@ class Distribution(Cell): | |||
| self._set_log_survival() | |||
| self._set_cross_entropy() | |||
| self.context_mode = context.get_context('mode') | |||
| self.checktuple = CheckTuple() | |||
| self.checktensor = CheckTensor() | |||
| @@ -108,6 +110,15 @@ class Distribution(Cell): | |||
| def broadcast_shape(self): | |||
| return self._broadcast_shape | |||
| def _check_value(self, value, name): | |||
| """ | |||
| Check availability fo value as a Tensor. | |||
| """ | |||
| if self.context_mode == 0: | |||
| self.checktensor(value, name) | |||
| return value | |||
| return self.checktensor(value, name) | |||
| def _set_prob(self): | |||
| """ | |||
| Set probability funtion based on the availability of _prob and _log_likehood. | |||
| @@ -100,7 +100,7 @@ class Exponential(Distribution): | |||
| """ | |||
| param = dict(locals()) | |||
| valid_dtype = mstype.float_type | |||
| check_type(dtype, valid_dtype, "Exponential") | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Exponential, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = dtype | |||
| if rate is not None: | |||
| @@ -146,7 +146,10 @@ class Exponential(Distribution): | |||
| Check availablity of distribution specific args rate. | |||
| """ | |||
| if rate is not None: | |||
| self.checktensor(rate, 'rate') | |||
| if self.context_mode == 0: | |||
| self.checktensor(rate, 'rate') | |||
| else: | |||
| rate = self.checktensor(rate, 'rate') | |||
| return self.cast(rate, self.parameter_type) | |||
| return self.rate if self.rate is not None else raise_none_error('rate') | |||
| @@ -210,7 +213,7 @@ class Exponential(Distribution): | |||
| .. math:: | |||
| pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 | |||
| """ | |||
| self.checktensor(value, "value") | |||
| value = self._check_value(value, "value") | |||
| value = self.cast(value, self.dtype) | |||
| rate = self._check_param(rate) | |||
| prob = self.exp(self.log(rate) - rate * value) | |||
| @@ -232,7 +235,7 @@ class Exponential(Distribution): | |||
| .. math:: | |||
| cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0 | |||
| """ | |||
| self.checktensor(value, 'value') | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, self.dtype) | |||
| rate = self._check_param(rate) | |||
| cdf = 1.0 - self.exp(-1. * rate * value) | |||
| @@ -251,7 +254,7 @@ class Exponential(Distribution): | |||
| rate_a (Tensor): rate of distribution a. Default: self.rate. | |||
| """ | |||
| check_distribution_name(dist, 'Exponential') | |||
| self.checktensor(rate_b, 'rate_b') | |||
| rate_b = self._check_value(rate_b, 'rate_b') | |||
| rate_b = self.cast(rate_b, self.parameter_type) | |||
| rate_a = self._check_param(rate) | |||
| return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 | |||
| @@ -102,7 +102,7 @@ class Geometric(Distribution): | |||
| """ | |||
| param = dict(locals()) | |||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | |||
| check_type(dtype, valid_dtype, "Geometric") | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Geometric, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = mstype.float32 | |||
| if probs is not None: | |||
| @@ -150,7 +150,10 @@ class Geometric(Distribution): | |||
| Check availablity of distribution specific args probs1. | |||
| """ | |||
| if probs1 is not None: | |||
| self.checktensor(probs1, 'probs1') | |||
| if self.context_mode == 0: | |||
| self.checktensor(probs1, 'probs1') | |||
| else: | |||
| probs1 = self.checktensor(probs1, 'probs1') | |||
| return self.cast(probs1, self.parameter_type) | |||
| return self.probs if self.probs is not None else raise_none_error('probs1') | |||
| @@ -211,7 +214,7 @@ class Geometric(Distribution): | |||
| pmf(k) = probs0 ^k * probs1 if k >= 0; | |||
| pmf(k) = 0 if k < 0. | |||
| """ | |||
| self.checktensor(value, 'value') | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, mstype.float32) | |||
| value = self.floor(value) | |||
| probs1 = self._check_param(probs1) | |||
| @@ -233,7 +236,7 @@ class Geometric(Distribution): | |||
| cdf(k) = 0 if k < 0. | |||
| """ | |||
| self.checktensor(value, 'value') | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, mstype.float32) | |||
| value = self.floor(value) | |||
| probs1 = self._check_param(probs1) | |||
| @@ -256,7 +259,7 @@ class Geometric(Distribution): | |||
| KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b}) | |||
| """ | |||
| check_distribution_name(dist, 'Geometric') | |||
| self.checktensor(probs1_b, 'probs1_b') | |||
| probs1_b = self._check_value(probs1_b, 'probs1_b') | |||
| probs1_b = self.cast(probs1_b, self.parameter_type) | |||
| probs1_a = self._check_param(probs1) | |||
| probs0_a = 1.0 - probs1_a | |||
| @@ -18,7 +18,7 @@ from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\ | |||
| from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ | |||
| raise_none_error | |||
| from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic | |||
| @@ -102,12 +102,12 @@ class Normal(Distribution): | |||
| """ | |||
| param = dict(locals()) | |||
| valid_dtype = mstype.float_type | |||
| check_type(dtype, valid_dtype, "Normal") | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Normal, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = dtype | |||
| if mean is not None and sd is not None: | |||
| self._mean_value = convert_to_batch(mean, self.broadcast_shape, self.parameter_type) | |||
| self._sd_value = convert_to_batch(sd, self.broadcast_shape, self.parameter_type) | |||
| self._mean_value = cast_to_tensor(mean, self.parameter_type) | |||
| self._sd_value = cast_to_tensor(sd, self.parameter_type) | |||
| check_greater_zero(self._sd_value, "Standard deviation") | |||
| else: | |||
| self._mean_value = mean | |||
| @@ -139,12 +139,18 @@ class Normal(Distribution): | |||
| Check availablity of distribution specific args mean and sd. | |||
| """ | |||
| if mean is not None: | |||
| self.checktensor(mean, 'mean') | |||
| if self.context_mode == 0: | |||
| self.checktensor(mean, 'mean') | |||
| else: | |||
| mean = self.checktensor(mean, 'mean') | |||
| mean = self.cast(mean, self.parameter_type) | |||
| else: | |||
| mean = self._mean_value if self._mean_value is not None else raise_none_error('mean') | |||
| if sd is not None: | |||
| self.checktensor(sd, 'sd') | |||
| if self.context_mode == 0: | |||
| self.checktensor(sd, 'sd') | |||
| else: | |||
| sd = self.checktensor(sd, 'sd') | |||
| sd = self.cast(sd, self.parameter_type) | |||
| else: | |||
| sd = self._sd_value if self._sd_value is not None else raise_none_error('sd') | |||
| @@ -210,7 +216,7 @@ class Normal(Distribution): | |||
| .. math:: | |||
| L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) | |||
| """ | |||
| self.checktensor(value, 'value') | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, self.dtype) | |||
| mean, sd = self._check_param(mean, sd) | |||
| unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) | |||
| @@ -229,7 +235,7 @@ class Normal(Distribution): | |||
| .. math:: | |||
| cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2)))) | |||
| """ | |||
| self.checktensor(value, 'value') | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, self.dtype) | |||
| mean, sd = self._check_param(mean, sd) | |||
| sqrt2 = self.sqrt(self.const(2.0)) | |||
| @@ -252,8 +258,8 @@ class Normal(Distribution): | |||
| 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) | |||
| """ | |||
| check_distribution_name(dist, 'Normal') | |||
| self.checktensor(mean_b, 'mean_b') | |||
| self.checktensor(sd_b, 'sd_b') | |||
| mean_b = self._check_value(mean_b, 'mean_b') | |||
| sd_b = self._check_value(sd_b, 'sd_b') | |||
| mean_b = self.cast(mean_b, self.parameter_type) | |||
| sd_b = self.cast(sd_b, self.parameter_type) | |||
| mean_a, sd_a = self._check_param(mean, sd) | |||
| @@ -46,10 +46,10 @@ class TransformedDistribution(Distribution): | |||
| Constructor of transformed_distribution class. | |||
| """ | |||
| param = dict(locals()) | |||
| validator.check_value_type('bijector', bijector, [nn.probability.bijector.Bijector], name) | |||
| validator.check_value_type('distribution', distribution, [Distribution], name) | |||
| validator.check_value_type('bijector', bijector, [nn.probability.bijector.Bijector], type(self).__name__) | |||
| validator.check_value_type('distribution', distribution, [Distribution], type(self).__name__) | |||
| valid_dtype = mstype.number_type | |||
| check_type(dtype, valid_dtype, "transformed_distribution") | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(TransformedDistribution, self).__init__(seed, dtype, name, param) | |||
| self._bijector = bijector | |||
| @@ -17,7 +17,7 @@ from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from .distribution import Distribution | |||
| from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\ | |||
| from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\ | |||
| raise_none_error | |||
| from ._utils.custom_ops import exp_generic, log_generic | |||
| @@ -101,12 +101,12 @@ class Uniform(Distribution): | |||
| """ | |||
| param = dict(locals()) | |||
| valid_dtype = mstype.float_type | |||
| check_type(dtype, valid_dtype, "Uniform") | |||
| check_type(dtype, valid_dtype, type(self).__name__) | |||
| super(Uniform, self).__init__(seed, dtype, name, param) | |||
| self.parameter_type = dtype | |||
| if low is not None and high is not None: | |||
| self._low = convert_to_batch(low, self.broadcast_shape, dtype) | |||
| self._high = convert_to_batch(high, self.broadcast_shape, dtype) | |||
| self._low = cast_to_tensor(low, dtype) | |||
| self._high = cast_to_tensor(high, dtype) | |||
| check_greater(self.low, self.high, "low value", "high value") | |||
| else: | |||
| self._low = low | |||
| @@ -142,12 +142,18 @@ class Uniform(Distribution): | |||
| Check availablity of distribution specific args low and high. | |||
| """ | |||
| if low is not None: | |||
| self.checktensor(low, 'low') | |||
| if self.context_mode == 0: | |||
| self.checktensor(low, 'low') | |||
| else: | |||
| low = self.checktensor(low, 'low') | |||
| low = self.cast(low, self.parameter_type) | |||
| else: | |||
| low = self.low if self.low is not None else raise_none_error('low') | |||
| if high is not None: | |||
| self.checktensor(high, 'high') | |||
| if self.context_mode == 0: | |||
| self.checktensor(high, 'high') | |||
| else: | |||
| high = self.checktensor(high, 'high') | |||
| high = self.cast(high, self.parameter_type) | |||
| else: | |||
| high = self.high if self.high is not None else raise_none_error('high') | |||
| @@ -231,7 +237,7 @@ class Uniform(Distribution): | |||
| pdf(x) = \frac{1.0}{high -low} if low <= x <= high; | |||
| pdf(x) = 0 if x > high; | |||
| """ | |||
| self.checktensor(value, 'value') | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, self.dtype) | |||
| low, high = self._check_param(low, high) | |||
| neg_ones = self.fill(self.dtype, self.shape(value), -1.0) | |||
| @@ -255,9 +261,9 @@ class Uniform(Distribution): | |||
| high_a (Tensor): upper bound of distribution a. Default: self.high. | |||
| """ | |||
| check_distribution_name(dist, 'Uniform') | |||
| self.checktensor(low_b, 'low_b') | |||
| low_b = self._check_value(low_b, 'low_b') | |||
| low_b = self.cast(low_b, self.parameter_type) | |||
| self.checktensor(high_b, 'high_b') | |||
| high_b = self._check_value(high_b, 'high_b') | |||
| high_b = self.cast(high_b, self.parameter_type) | |||
| low_a, high_a = self._check_param(low, high) | |||
| kl = self.log(high_b - low_b) - self.log(high_a - low_a) | |||
| @@ -278,7 +284,7 @@ class Uniform(Distribution): | |||
| cdf(x) = \frac{x - low}{high -low} if low <= x <= high; | |||
| cdf(x) = 1 if x > high; | |||
| """ | |||
| self.checktensor(value, 'value') | |||
| value = self._check_value(value, 'value') | |||
| value = self.cast(value, self.dtype) | |||
| low, high = self._check_param(low, high) | |||
| prob = (value - low) / (high - low) | |||