Merge pull request !5092 from peixu_ren/custom_bijector
@@ -17,14 +17,14 @@ from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from ..distribution._utils.utils import CheckTensor
-from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step
+from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
 from .bijector import Bijector

 class PowerTransform(Bijector):
     r"""
     Power Bijector.
-    This Bijector performs the operation: Y = g(X) = (1 + X * c)^(1 / c), X >= -1 / c, where c is power.
+    This Bijector performs the operation: Y = g(X) = (1 + X * c)^(1 / c), X >= -1 / c, where c >= 0 is the power.

     The power transform maps inputs from `[-1/c, inf]` to `[0, inf]`.
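
For orientation, a standalone numpy sketch of the forward and inverse maps described in the docstring (hypothetical helper names, not MindSpore code):

    import numpy as np

    def power_transform(x, c):
        # Forward map: y = (1 + x * c)^(1 / c) on x >= -1/c;
        # the c == 0 limit is y = exp(x).
        return np.exp(x) if c == 0 else np.power(1.0 + x * c, 1.0 / c)

    def power_transform_inverse(y, c):
        # Inverse map: x = (y^c - 1) / c, with the c == 0 limit x = log(y).
        return np.log(y) if c == 0 else (np.power(y, c) - 1.0) / c

    x = np.linspace(-0.4, 3.0, 5)
    assert np.allclose(power_transform_inverse(power_transform(x, 2.0), 2.0), x)
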
@@ -61,10 +61,10 @@ class PowerTransform(Bijector):
         validator.check_number("power", power, 0, Rel.GE, self.name)
         self._power = power
         self.pow = P.Pow()
-        self.exp = exp_by_step
-        self.expm1 = expm1_by_step
-        self.log = log_by_step
-        self.log1p = log1p_by_step
+        self.exp = exp_generic
+        self.expm1 = expm1_generic
+        self.log = log_generic
+        self.log1p = log1p_generic
         self.checktensor = CheckTensor()
@@ -16,7 +16,7 @@
 from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
 from ..distribution._utils.utils import cast_to_tensor, CheckTensor
-from ..distribution._utils.custom_ops import log_by_step
+from ..distribution._utils.custom_ops import log_generic
 from .bijector import Bijector
@@ -69,7 +69,7 @@ class ScalarAffine(Bijector):
                          param=param)
         self.abs = P.Abs()
-        self.log = log_by_step
+        self.log = log_generic
         self.checktensor = CheckTensor()
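
ScalarAffine only needs a log op because, for g(x) = scale * x + shift, the log-determinant of the Jacobian is the constant log|scale|. A one-line numpy illustration (hypothetical helper, not the PR's code):

    import numpy as np

    def scalar_affine_log_det_jacobian(scale):
        # Constant log-det-Jacobian of x -> scale * x + shift.
        return np.log(np.abs(scale))
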
@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
 from mindspore.nn.layer.activation import LogSigmoid
 from mindspore._checkparam import Validator as validator
 from ..distribution._utils.utils import cast_to_tensor, CheckTensor
-from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
+from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic
 from .bijector import Bijector
@@ -61,9 +61,9 @@ class Softplus(Bijector):
         super(Softplus, self).__init__(name=name, param=param)
         self._sharpness = cast_to_tensor(sharpness)
-        self.exp = exp_by_step
-        self.log = log_by_step
-        self.expm1 = expm1_by_step
+        self.exp = exp_generic
+        self.log = log_generic
+        self.expm1 = expm1_generic
         self.abs = P.Abs()
         self.fill = P.Fill()
         self.greater = P.Greater()
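
The Abs, Fill, and Greater ops point at the usual overflow-safe softplus evaluation; a minimal numpy sketch of that trick (an assumption about intent, not the PR's exact code):

    import numpy as np

    def softplus_stable(x, sharpness=1.0):
        # softplus(k*x)/k computed as (max(k*x, 0) + log1p(exp(-|k*x|))) / k,
        # which cannot overflow in exp for large positive inputs.
        kx = sharpness * np.asarray(x, dtype=np.float64)
        return (np.maximum(kx, 0.0) + np.log1p(np.exp(-np.abs(kx)))) / sharpness
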
@@ -28,8 +28,10 @@ __all__ = [
     'check_scalar_from_param',
     'check_prob',
     'check_type',
-    'exp_by_step',
-    'expm1_by_step',
-    'log_by_step',
-    'log1p_by_step',
+    'exp_generic',
+    'expm1_generic',
+    'log_generic',
+    'log1p_generic',
+    'erf_generic',
+    'erfc_generic',
 ]
@@ -17,8 +17,7 @@ import numpy as np
 from mindspore.ops import operations as P
 from mindspore.common import dtype as mstype

-def exp_by_step(input_x):
+def exp_generic(input_x):
     """
     Log op on Ascend doesn't support int types.
     Fix this by casting the type.
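
A standalone numpy sketch of the workaround this docstring describes, i.e. casting non-float inputs before the exponential (the op body itself is elided from this hunk):

    import numpy as np

    def exp_with_cast(input_x):
        # Ascend kernels reject integer dtypes, so cast to float32 first.
        arr = np.asarray(input_x)
        if not np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(np.float32)
        return np.exp(arr)
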
@@ -30,14 +29,14 @@ def exp_by_step(input_x):
     return exp(input_x)

-def expm1_by_step(input_x):
+def expm1_generic(input_x):
     """
     Expm1 op under GPU context.
     """
-    return exp_by_step(input_x) - 1.0
+    return exp_generic(input_x) - 1.0

-def log_by_step(input_x):
+def log_generic(input_x):
     """
     Log op on Ascend is calculated as log(abs(x)).
     Fix this by mapping negative values to NaN.
     """
@@ -63,8 +62,166 @@ def log_by_step(input_x):
     return select(neg_x, nan, result)

-def log1p_by_step(x):
+def log1p_generic(x):
     """
     Log1p op on GPU device or when device_target == GPU.
     """
-    return log_by_step(x + 1.0)
+    return log_generic(x + 1.0)
+
+
+def _evaluate_polynomial(x, coefficients):
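+    # Horner's method: coefficients run from the highest-degree term down
+    # to the constant term.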
+    poly = 0
+    for co in coefficients:
+        poly = poly * x + co
+    return poly
+
+
+def erf_f32_generic(x):
+    """
+    Calculate erf for dtype of f32.
+    """
+    k_erf_tcoefficient = [+7.853861353153693e-5,
+                          -8.010193625184903e-4,
+                          +5.188327685732524e-3,
+                          -2.685381193529856e-2,
+                          +1.128358514861418e-1,
+                          -3.761262582423300e-1,
+                          +1.128379165726710e+0]
+    poly = _evaluate_polynomial(x * x, k_erf_tcoefficient)
+    return x * poly
+
+
+def erf_f64_generic(x):
+    """
+    Calculate erf for dtype of f64.
+    """
+    k_erf_tcoefficient = [9.60497373987051638749e0,
+                          9.00260197203842689217e1,
+                          2.23200534594684319226e3,
+                          7.00332514112805075473e3,
+                          5.55923013010394962768e4]
+    k_erf_ucoefficient = [1.00000000000000000000e0,
+                          3.35617141647503099647e1,
+                          5.21357949780152679795e2,
+                          4.59432382970980127987e3,
+                          2.26290000613890934246e4,
+                          4.92673942608635921086e4]
+    z = x * x
+    poly1 = _evaluate_polynomial(z, k_erf_tcoefficient)
+    poly2 = _evaluate_polynomial(z, k_erf_ucoefficient)
+    return x * poly1 / poly2
+
+
+def erfc_f32_generic(x):
+    """
+    Calculate erfc for dtype of f32.
+    """
+    k_maxlog = 88.72283905206835
+    k_erfc_pcoefficient = [+2.326819970068386e-2,
+                           -1.387039388740657e-1,
+                           +3.687424674597105e-1,
+                           -5.824733027278666e-1,
+                           +6.210004621745983e-1,
+                           -4.944515323274145e-1,
+                           +3.404879937665872e-1,
+                           -2.741127028184656e-1,
+                           +5.638259427386472e-1]
+    k_erfc_rcoefficient = [-1.047766399936249e+1,
+                           +1.297719955372516e+1,
+                           -7.495518717768503e+0,
+                           +2.921019019210786e+0,
+                           -1.015265279202700e+0,
+                           +4.218463358204948e-1,
+                           -2.820767439740514e-1,
+                           +5.641895067754075e-1]
+    abs_cal = P.Abs()
+    select = P.Select()
+    less = P.Less()
+    fill = P.Fill()
+    dtype = P.DType()
+    shape = P.Shape()
+
+    abs_x = abs_cal(x)
+    z = exp_generic(-x * x)
+    q = 1 / abs_x
+    y = q * q
+    poly1 = _evaluate_polynomial(y, k_erfc_pcoefficient)
+    poly2 = _evaluate_polynomial(y, k_erfc_rcoefficient)
+    p = select(less(abs_x, 2.0), poly1, poly2)
+    y = z * q * p
+    zeros = fill(dtype(x), shape(x), 0)
+    y_clamp = select(less(z, -k_maxlog), zeros, y)
+    return select(less(x, 0), 2.0 - y_clamp, y_clamp)
+
+
+def erfc_f64_generic(x):
+    """
+    Calculate erfc for dtype of f64.
+    """
+    k_maxlog = 7.09782712893383996843e2
+    k_erfc_pcoefficient = [2.46196981473530512524e-10,
+                           5.64189564831068821977e-1,
+                           7.46321056442269912687e0,
+                           4.86371970985681366614e1,
+                           1.96520832956077098242e2,
+                           5.26445194995477358631e2,
+                           9.34528527171957607540e2,
+                           1.02755188689515710272e3,
+                           5.57535335369399327526e2]
+    k_erfc_qcoefficient = [1.00000000000000000000e0,
+                           1.32281951154744992508e1,
+                           8.67072140885989742329e1,
+                           3.54937778887819891062e2,
+                           9.75708501743205489753e2,
+                           1.82390916687909736289e3,
+                           2.24633760818710981792e3,
+                           1.65666309194161350182e3,
+                           5.57535340817727675546e2]
+    k_erfc_rcoefficient = [5.64189583547755073984e-1,
+                           1.27536670759978104416e0,
+                           5.01905042251180477414e0,
+                           6.16021097993053585195e0,
+                           7.40974269950448939160e0,
+                           2.97886665372100240670e0]
+    k_erfc_scoefficient = [1.00000000000000000000e0,
+                           2.26052863220117276590e0,
+                           9.39603524938001434673e0,
+                           1.20489539808096656605e1,
+                           1.70814450747565897222e1,
+                           9.60896809063285878198e0,
+                           3.36907645100081516050e0]
+    abs_cal = P.Abs()
+    select = P.Select()
+    less = P.Less()
+    fill = P.Fill()
+    dtype = P.DType()
+    shape = P.Shape()
+
+    abs_x = abs_cal(x)
+    z = -x * x
+    exp_z = exp_generic(z)
+    temp1 = exp_z * _evaluate_polynomial(abs_x, k_erfc_pcoefficient) / _evaluate_polynomial(abs_x, k_erfc_qcoefficient)
+    temp2 = exp_z * _evaluate_polynomial(abs_x, k_erfc_rcoefficient) / _evaluate_polynomial(abs_x, k_erfc_scoefficient)
+    y = select(less(abs_x, 8.0), temp1, temp2)
+    zeros = fill(dtype(x), shape(x), 0)
+    y_clamp = select(less(z, -k_maxlog), zeros, y)
+    return select(less(x, 0), 2.0 - y_clamp, y_clamp)
+
+
+def erfc_generic(x):
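+    """
+    Calculate erfc: use the erfc kernel where |x| > 1 and the identity
+    erfc(x) = 1 - erf(x) elsewhere. Both branches use the f32 kernels.
+    """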
+    select = P.Select()
+    greater = P.Greater()
+    abs_cal = P.Abs()
+    return select(greater(abs_cal(x), 1), erfc_f32_generic(x), 1 - erf_f32_generic(x))
+
+
+def erf_generic(x):
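+    """
+    Calculate erf: use the erf kernel where |x| < 1 and the identity
+    erf(x) = 1 - erfc(x) elsewhere. Both branches use the f32 kernels.
+    """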
+    select = P.Select()
+    less = P.Less()
+    abs_cal = P.Abs()
+    return select(less(abs_cal(x), 1), erf_f32_generic(x), 1 - erfc_f32_generic(x))
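
The split at |x| = 1 keeps each kernel in its accurate regime: the polynomial form of erf degrades in the tails, while the asymptotic form of erfc degrades near zero. A standalone pure-Python illustration of the same dispatch pattern, with the standard library's erf/erfc standing in for the polynomial kernels:

    import math

    def erf_dispatch(v):
        # Mirrors erf_generic: direct kernel near 0, complement in the tails.
        return math.erf(v) if abs(v) < 1 else 1.0 - math.erfc(v)

    def erfc_dispatch(v):
        # Mirrors erfc_generic: direct kernel in the tails, complement near 0.
        return math.erfc(v) if abs(v) > 1 else 1.0 - math.erf(v)

    for v in (-3.0, -0.5, 0.0, 0.5, 3.0):
        assert math.isclose(erf_dispatch(v) + erfc_dispatch(v), 1.0, rel_tol=1e-12)
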
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
-from ._utils.custom_ops import exp_by_step, log_by_step
+from ._utils.custom_ops import exp_generic, log_generic, erf_generic

 class Bernoulli(Distribution):
@@ -109,13 +109,13 @@ class Bernoulli(Distribution):
         self._probs = probs

         # ops needed for the class
-        self.exp = exp_by_step
-        self.log = log_by_step
+        self.exp = exp_generic
+        self.log = log_generic
+        self.erf = erf_generic
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
         self.dtypeop = P.DType()
-        self.erf = P.Erf()
         self.floor = P.Floor()
         self.fill = P.Fill()
         self.less = P.Less()
@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
     raise_none_error
-from ._utils.custom_ops import exp_by_step, log_by_step
+from ._utils.custom_ops import exp_generic, log_generic

 class Exponential(Distribution):
     """
@@ -112,8 +112,8 @@ class Exponential(Distribution):
         self.minval = np.finfo(np.float).tiny

         # ops needed for the class
-        self.exp = exp_by_step
-        self.log = log_by_step
+        self.exp = exp_generic
+        self.log = log_generic
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
     raise_none_error
-from ._utils.custom_ops import exp_by_step, log_by_step
+from ._utils.custom_ops import exp_generic, log_generic

 class Geometric(Distribution):
@@ -114,8 +114,8 @@ class Geometric(Distribution):
         self.minval = np.finfo(np.float).tiny

         # ops needed for the class
-        self.exp = exp_by_step
-        self.log = log_by_step
+        self.exp = exp_generic
+        self.log = log_generic
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
     raise_none_error
-from ._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
+from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic

 class Normal(Distribution):
     """
@@ -114,13 +114,13 @@ class Normal(Distribution):
         self._sd_value = sd

         #ops needed for the class
-        self.exp = exp_by_step
-        self.expm1 = expm1_by_step
-        self.log = log_by_step
+        self.exp = exp_generic
+        self.expm1 = expm1_generic
+        self.log = log_generic
+        self.erf = erf_generic
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
-        self.erf = P.Erf()
         self.fill = P.Fill()
         self.shape = P.Shape()
         self.sq = P.Square()
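
Normal is the main consumer of the new erf_generic: its CDF is defined through erf. A standalone sketch of that relationship (illustrative only, not the MindSpore implementation):

    import math

    def normal_cdf(x, mean=0.0, sd=1.0):
        # CDF of N(mean, sd^2) written in terms of erf; this formula is
        # what the distribution needs an erf op for.
        return 0.5 * (1.0 + math.erf((x - mean) / (sd * math.sqrt(2.0))))

    assert math.isclose(normal_cdf(0.0), 0.5)
    assert math.isclose(normal_cdf(1.96), 0.975, abs_tol=1e-3)
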
@@ -18,7 +18,7 @@ from mindspore.common import dtype as mstype
 import mindspore.nn as nn
 from .distribution import Distribution
 from ._utils.utils import check_type, raise_not_impl_error
-from ._utils.custom_ops import exp_by_step, log_by_step
+from ._utils.custom_ops import exp_generic, log_generic

 class TransformedDistribution(Distribution):
     """
@@ -55,8 +55,8 @@ class TransformedDistribution(Distribution):
         self._bijector = bijector
         self._distribution = distribution
         self._is_linear_transformation = bijector.is_constant_jacobian
-        self.exp = exp_by_step
-        self.log = log_by_step
+        self.exp = exp_generic
+        self.log = log_generic

     @property
     def bijector(self):
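
The exp and log ops feed the change-of-variables formula a transformed distribution evaluates: log p_Y(y) = log p_X(g^(-1)(y)) + log|d g^(-1)(y) / dy|. A standalone sketch for Y = exp(X) with X ~ N(0, 1), i.e. a log-normal (illustrative only):

    import math

    def lognormal_log_prob(y):
        # Invert the bijector, score the base distribution, and add the
        # log-Jacobian of the inverse map (d g^(-1)/dy = 1/y here).
        x = math.log(y)
        base_log_prob = -0.5 * (x * x + math.log(2.0 * math.pi))
        return base_log_prob - math.log(y)
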
@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
     raise_none_error
-from ._utils.custom_ops import exp_by_step, log_by_step
+from ._utils.custom_ops import exp_generic, log_generic

 class Uniform(Distribution):
     """
@@ -113,8 +113,8 @@ class Uniform(Distribution):
         self._high = high

         # ops needed for the class
-        self.exp = exp_by_step
-        self.log = log_by_step
+        self.exp = exp_generic
+        self.log = log_generic
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()