Merge pull request !7048 from peixu_ren/custom_bijector
@@ -15,6 +15,7 @@
 """Bijector"""
 from mindspore import context
 from mindspore.nn.cell import Cell
+from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
 from ..distribution._utils.utils import CheckTensor
 from ..distribution import Distribution
@@ -62,6 +63,10 @@ class Bijector(Cell):
         self.context_mode = context.get_context('mode')
         self.checktensor = CheckTensor()
+        # ops needed for the base class
+        self.cast_base = P.Cast()
+        self.dtype_base = P.DType()
 
     @property
     def name(self):
         return self._name
@@ -91,6 +96,10 @@ class Bijector(Cell):
             return value
         return self.checktensor(value, name)
 
+    def cast_param_by_value(self, value, para):
+        local = self.cast_base(para, self.dtype_base(value))
+        return local
+
     def forward(self, *args, **kwargs):
         """
         Forward transformation: transform the input value to another distribution.
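For context, the new cast_param_by_value helper simply casts a parameter tensor to the dtype of the incoming value. A minimal standalone sketch of the same pattern (names are illustrative and assume a MindSpore install; this snippet is not part of the PR):

import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P

cast = P.Cast()
dtypeop = P.DType()

def cast_param_by_value(value, para):
    # Cast parameter `para` to the dtype of input `value`, so downstream
    # ops see matching dtypes regardless of the input's precision.
    return cast(para, dtypeop(value))

x = Tensor([0.5, 1.0], ms.float16)
power = Tensor(2.0, ms.float32)
power_local = cast_param_by_value(x, power)  # dtype follows x: float16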
@@ -69,6 +69,8 @@ class PowerTransform(Bijector):
         validator.check_number("power", power, 0, Rel.GE, self.name)
         self._power = power
         self.pow = P.Pow()
+        self.dtypeop = P.DType()
+        self.cast = P.Cast()
         self.exp = exp_generic
         self.expm1 = expm1_generic
         self.log = log_generic
@@ -87,15 +89,21 @@ class PowerTransform(Bijector):
     def _forward(self, x):
         x = self._check_value(x, 'value')
-        if self.power == 0:
-            return self.exp(x)
-        return self.exp(self.log1p(x * self.power) / self.power)
+        power_local = self.cast_param_by_value(x, self.power)
+        if power_local == 0:
+            forward_v = self.exp(x)
+        else:
+            forward_v = self.exp(self.log1p(x * power_local) / power_local)
+        return forward_v
 
     def _inverse(self, y):
         y = self._check_value(y, 'value')
-        if self.power == 0:
-            return self.log(y)
-        return self.expm1(self.log(y) * self.power) / self.power
+        power_local = self.cast_param_by_value(y, self.power)
+        if power_local == 0:
+            inverse_v = self.log(y)
+        else:
+            inverse_v = self.expm1(self.log(y) * power_local) / power_local
+        return inverse_v
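With the cast in place, PowerTransform follows the input dtype instead of assuming float32. A hedged usage sketch (import path as in MindSpore 1.x; verify against your installed version):

import mindspore as ms
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor

power_transform = msb.PowerTransform(power=2.)
x = Tensor([0.5, 1.0, 2.0], ms.float32)
y = power_transform.forward(x)       # (1 + 2x)^(1/2)
x_back = power_transform.inverse(y)  # recovers x up to rounding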
     def _forward_log_jacobian(self, x):
         r"""
@@ -110,9 +118,12 @@ class PowerTransform(Bijector):
         \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
         """
         x = self._check_value(x, 'value')
-        if self.power == 0:
-            return x
-        return (1. / self.power - 1) * self.log1p(x * self.power)
+        power_local = self.cast_param_by_value(x, self.power)
+        if power_local == 0:
+            forward_log_j = x
+        else:
+            forward_log_j = (1. / power_local - 1) * self.log1p(x * power_local)
+        return forward_log_j
 
     def _inverse_log_jacobian(self, y):
         r"""
@@ -127,4 +138,6 @@ class PowerTransform(Bijector):
         \log(f'(y)) = \log(\frac{e^{c \log(y)}}{y}) = (c-1) * \log(y)
         """
         y = self._check_value(y, 'value')
-        return (self.power - 1) * self.log(y)
+        power_local = self.cast_param_by_value(y, self.power)
+        inverse_log_j = (power_local - 1) * self.log(y)
+        return inverse_log_j
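The two log-Jacobian branches above follow from f(x) = (1 + cx)^{1/c} and f^{-1}(y) = (y^c - 1)/c. A quick NumPy check of both identities (illustrative only, not part of the PR):

import numpy as np

c, x = 2.0, 1.5
y = (1. + c * x) ** (1. / c)
# forward: log f'(x) = (1/c - 1) * log(1 + c*x), since f'(x) = (1+cx)^(1/c - 1)
assert np.isclose((1. / c - 1.) * np.log1p(c * x),
                  np.log((1. + c * x) ** (1. / c - 1.)))
# inverse: log (f^{-1})'(y) = (c - 1) * log(y), since (f^{-1})'(y) = y^(c-1)
assert np.isclose((c - 1.) * np.log(y), np.log(y ** (c - 1.)))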
@@ -76,6 +76,8 @@ class ScalarAffine(Bijector):
         self.abs = P.Abs()
         self.oneslike = P.OnesLike()
         self.dtypeop = P.DType()
+        self.cast = P.Cast()
         self.log = log_generic
 
     @property
@@ -99,7 +101,10 @@ class ScalarAffine(Bijector):
         f(x) = a * x + b
         """
         x = self._check_value(x, 'value')
-        return self.scale * x + self.shift * self.oneslike(x)
+        scale_local = self.cast_param_by_value(x, self.scale)
+        shift_local = self.cast_param_by_value(x, self.shift)
+        forward_v = scale_local * x + shift_local * self.oneslike(x)
+        return forward_v
 
     def _inverse(self, y):
         r"""
@@ -107,7 +112,10 @@ class ScalarAffine(Bijector):
         f^{-1}(y) = \frac{y - b}{a}
         """
         y = self._check_value(y, 'value')
-        return (y - self.shift) / self.scale
+        scale_local = self.cast_param_by_value(y, self.scale)
+        shift_local = self.cast_param_by_value(y, self.shift)
+        inverse_v = (y - shift_local) / scale_local
+        return inverse_v
 
     def _forward_log_jacobian(self, x):
         r"""
@@ -117,7 +125,9 @@ class ScalarAffine(Bijector):
         \log(f'(x)) = \log(a)
         """
         x = self._check_value(x, 'value')
-        return self.log(self.abs(self.scale))
+        scale_local = self.cast_param_by_value(x, self.scale)
+        forward_log_j = self.log(self.abs(scale_local))
+        return forward_log_j
 
     def _inverse_log_jacobian(self, y):
         r"""
@@ -127,4 +137,6 @@ class ScalarAffine(Bijector):
         \log(f'(x)) = - \log(a)
         """
         y = self._check_value(y, 'value')
-        return -1. * self.log(self.abs(self.scale))
+        scale_local = self.cast_param_by_value(y, self.scale)
+        inverse_log_j = -1. * self.log(self.abs(scale_local))
+        return inverse_log_j
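ScalarAffine gets the same treatment: scale and shift are cast to the input's dtype before use. A hedged usage sketch (same import-path caveat as above):

import mindspore as ms
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor

affine = msb.ScalarAffine(scale=2.0, shift=1.0)
x = Tensor([1.0, 2.0, 3.0], ms.float32)
y = affine.forward(x)                     # 2*x + 1
log_det = affine.forward_log_jacobian(x)  # log|scale| = log(2)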
@@ -71,6 +71,7 @@ class Softplus(Bijector):
         self.expm1 = expm1_generic
         self.abs = P.Abs()
         self.dtypeop = P.DType()
+        self.cast = P.Cast()
         self.fill = P.Fill()
         self.greater = P.Greater()
         self.less = P.Less()
@@ -125,8 +126,10 @@ class Softplus(Bijector):
     def _forward(self, x):
         x = self._check_value(x, 'value')
-        scaled_value = self.sharpness * x
-        return self.softplus(scaled_value) / self.sharpness
+        sharpness_local = self.cast_param_by_value(x, self.sharpness)
+        scaled_value = sharpness_local * x
+        forward_v = self.softplus(scaled_value) / sharpness_local
+        return forward_v
 
     def _inverse(self, y):
         r"""
@@ -135,8 +138,10 @@ class Softplus(Bijector):
         f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
         """
         y = self._check_value(y, 'value')
-        scaled_value = self.sharpness * y
-        return self.inverse_softplus(scaled_value) / self.sharpness
+        sharpness_local = self.cast_param_by_value(y, self.sharpness)
+        scaled_value = sharpness_local * y
+        inverse_v = self.inverse_softplus(scaled_value) / sharpness_local
+        return inverse_v
 
     def _forward_log_jacobian(self, x):
         r"""
@@ -146,8 +151,10 @@ class Softplus(Bijector):
         \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
         """
         x = self._check_value(x, 'value')
-        scaled_value = self.sharpness * x
-        return self.log_sigmoid(scaled_value)
+        sharpness_local = self.cast_param_by_value(x, self.sharpness)
+        scaled_value = sharpness_local * x
+        forward_log_j = self.log_sigmoid(scaled_value)
+        return forward_log_j
 
     def _inverse_log_jacobian(self, y):
         r"""
@@ -157,5 +164,7 @@ class Softplus(Bijector):
         \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
         """
         y = self._check_value(y, 'value')
-        scaled_value = self.sharpness * y
-        return scaled_value - self.inverse_softplus(scaled_value)
+        sharpness_local = self.cast_param_by_value(y, self.sharpness)
+        scaled_value = sharpness_local * y
+        inverse_log_j = scaled_value - self.inverse_softplus(scaled_value)
+        return inverse_log_j
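The scaled softplus pair used here round-trips exactly: g(x) = log(1 + e^{kx})/k and g^{-1}(y) = log(e^{ky} - 1)/k. A NumPy sanity check of the docstring identities above (illustrative only):

import numpy as np

k, x = 1.5, 0.7
y = np.log1p(np.exp(k * x)) / k       # forward
x_back = np.log(np.expm1(k * y)) / k  # inverse
assert np.isclose(x, x_back)
# forward log-Jacobian: log g'(x) = log sigmoid(k*x) = k*x - log(1 + e^{k*x})
assert np.isclose(k * x - np.log1p(np.exp(k * x)),
                  np.log(1. / (1. + np.exp(-k * x))))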
@@ -66,10 +66,10 @@ def normal(shape, mean, stddev, seed=None):
     Args:
         shape (tuple): The shape of random tensor to be generated.
-        mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak.
-            with float32 data type.
-        stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0.
-            with float32 data type.
+        mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak,
+            with data type in [int8, int16, int32, int64, float16, float32].
+        stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0,
+            with data type in [int8, int16, int32, int64, float16, float32].
         seed (int): Seed is used as entropy source for the Random number engines to generate pseudo-random numbers,
             must be non-negative. Default: None, which will be treated as 0.
@@ -86,8 +86,8 @@ def normal(shape, mean, stddev, seed=None):
     """
     mean_dtype = F.dtype(mean)
     stddev_dtype = F.dtype(stddev)
-    const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal")
-    const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal")
+    const_utils.check_valid_type(mean_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
+    const_utils.check_valid_type(stddev_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
     seed1, seed2 = get_seed(seed, "normal")
     stdnormal = P.StandardNormal(seed1, seed2)
     random_normal = stdnormal(shape)
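With the relaxed check, normal accepts integer and float16 parameters in addition to float32. A hedged sketch; the function lives in mindspore.ops.composite in releases of this vintage, so adjust the import to your version:

import mindspore as ms
import mindspore.ops.composite as C
from mindspore import Tensor

mean = Tensor(1.0, ms.float16)    # float16 now passes the dtype check
stddev = Tensor(2.0, ms.float16)  # likewise; int8..int64 are also accepted
samples = C.normal((2, 3), mean, stddev, seed=5)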