| @@ -15,7 +15,6 @@ | |||||
| """Softplus Bijector""" | """Softplus Bijector""" | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.nn.layer.activation import LogSigmoid | from mindspore.nn.layer.activation import LogSigmoid | ||||
| from mindspore._checkparam import Validator as validator | from mindspore._checkparam import Validator as validator | ||||
| from ..distribution._utils.utils import cast_to_tensor | from ..distribution._utils.utils import cast_to_tensor | ||||
| @@ -71,6 +70,7 @@ class Softplus(Bijector): | |||||
| self.log = log_generic | self.log = log_generic | ||||
| self.expm1 = expm1_generic | self.expm1 = expm1_generic | ||||
| self.abs = P.Abs() | self.abs = P.Abs() | ||||
| self.dtypeop = P.DType() | |||||
| self.fill = P.Fill() | self.fill = P.Fill() | ||||
| self.greater = P.Greater() | self.greater = P.Greater() | ||||
| self.less = P.Less() | self.less = P.Less() | ||||
| @@ -90,7 +90,7 @@ class Softplus(Bijector): | |||||
| too_large = self.greater(x, -self.threshold) | too_large = self.greater(x, -self.threshold) | ||||
| too_small_value = self.exp(x) | too_small_value = self.exp(x) | ||||
| too_large_value = x | too_large_value = x | ||||
| ones = self.fill(mstype.float32, self.shape(x), 1.0) | |||||
| ones = self.fill(self.dtypeop(x), self.shape(x), 1.0) | |||||
| too_small_or_too_large = self.logicalor(too_small, too_large) | too_small_or_too_large = self.logicalor(too_small, too_large) | ||||
| x = self.select(too_small_or_too_large, ones, x) | x = self.select(too_small_or_too_large, ones, x) | ||||
| y = self.log(self.exp(x) + 1.0) | y = self.log(self.exp(x) + 1.0) | ||||
| @@ -106,7 +106,7 @@ class Softplus(Bijector): | |||||
| too_large = self.greater(x, -self.threshold) | too_large = self.greater(x, -self.threshold) | ||||
| too_small_value = self.log(x) | too_small_value = self.log(x) | ||||
| too_large_value = x | too_large_value = x | ||||
| ones = self.fill(mstype.float32, self.shape(x), 1.0) | |||||
| ones = self.fill(self.dtypeop(x), self.shape(x), 1.0) | |||||
| too_small_or_too_large = self.logicalor(too_small, too_large) | too_small_or_too_large = self.logicalor(too_small, too_large) | ||||
| x = self.select(too_small_or_too_large, ones, x) | x = self.select(too_small_or_too_large, ones, x) | ||||
| y = x + self.log(self.abs(self.expm1(-x))) | y = x + self.log(self.abs(self.expm1(-x))) | ||||
| @@ -24,8 +24,11 @@ def exp_generic(input_x): | |||||
| """ | """ | ||||
| exp = P.Exp() | exp = P.Exp() | ||||
| cast = P.Cast() | cast = P.Cast() | ||||
| dtype = P.DType() | |||||
| checktype = P.IsSubClass() | |||||
| input_x = cast(input_x, mstype.float32) | |||||
| if not checktype(dtype(input_x), mstype.float_): | |||||
| input_x = cast(input_x, mstype.float32) | |||||
| return exp(input_x) | return exp(input_x) | ||||
| @@ -51,8 +54,10 @@ def log_generic(input_x): | |||||
| dtype = P.DType() | dtype = P.DType() | ||||
| shape = P.Shape() | shape = P.Shape() | ||||
| select = P.Select() | select = P.Select() | ||||
| checktype = P.IsSubClass() | |||||
| input_x = cast(input_x, mstype.float32) | |||||
| if not checktype(dtype(input_x), mstype.float_): | |||||
| input_x = cast(input_x, mstype.float32) | |||||
| nan = fill(dtype(input_x), shape(input_x), np.nan) | nan = fill(dtype(input_x), shape(input_x), np.nan) | ||||
| inf = fill(dtype(input_x), shape(input_x), np.inf) | inf = fill(dtype(input_x), shape(input_x), np.inf) | ||||
| neg_x = less(input_x, 0.0) | neg_x = less(input_x, 0.0) | ||||
| @@ -222,7 +222,7 @@ class Bernoulli(Distribution): | |||||
| pmf(k) = probs0 if k = 0; | pmf(k) = probs0 if k = 0; | ||||
| """ | """ | ||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, mstype.float32) | |||||
| value = self.cast(value, self.parameter_type) | |||||
| probs1 = self._check_param_type(probs1) | probs1 = self._check_param_type(probs1) | ||||
| probs0 = 1.0 - probs1 | probs0 = 1.0 - probs1 | ||||
| return self.log(probs1) * value + self.log(probs0) * (1.0 - value) | return self.log(probs1) * value + self.log(probs0) * (1.0 - value) | ||||
| @@ -241,7 +241,7 @@ class Bernoulli(Distribution): | |||||
| cdf(k) = 1 if k >=1; | cdf(k) = 1 if k >=1; | ||||
| """ | """ | ||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, mstype.float32) | |||||
| value = self.cast(value, self.parameter_type) | |||||
| value = self.floor(value) | value = self.floor(value) | ||||
| probs1 = self._check_param_type(probs1) | probs1 = self._check_param_type(probs1) | ||||
| prob_type = self.dtypeop(probs1) | prob_type = self.dtypeop(probs1) | ||||
| @@ -225,7 +225,7 @@ class Geometric(Distribution): | |||||
| pmf(k) = 0 if k < 0. | pmf(k) = 0 if k < 0. | ||||
| """ | """ | ||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, mstype.float32) | |||||
| value = self.cast(value, self.parameter_type) | |||||
| value = self.floor(value) | value = self.floor(value) | ||||
| probs1 = self._check_param_type(probs1) | probs1 = self._check_param_type(probs1) | ||||
| pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) | pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) | ||||
| @@ -247,7 +247,7 @@ class Geometric(Distribution): | |||||
| """ | """ | ||||
| value = self._check_value(value, 'value') | value = self._check_value(value, 'value') | ||||
| value = self.cast(value, mstype.float32) | |||||
| value = self.cast(value, self.parameter_type) | |||||
| value = self.floor(value) | value = self.floor(value) | ||||
| probs1 = self._check_param_type(probs1) | probs1 = self._check_param_type(probs1) | ||||
| probs0 = 1.0 - probs1 | probs0 = 1.0 - probs1 | ||||