Browse Source

Fix errors in exp calculation logics

tags/v0.7.0-beta
peixu_ren 5 years ago
parent
commit
03dac9b621
10 changed files with 61 additions and 36 deletions
  1. +3
    -3
      mindspore/nn/probability/bijector/power_transform.py
  2. +3
    -3
      mindspore/nn/probability/bijector/softplus.py
  3. +2
    -1
      mindspore/nn/probability/distribution/_utils/__init__.py
  4. +33
    -8
      mindspore/nn/probability/distribution/_utils/custom_ops.py
  5. +3
    -3
      mindspore/nn/probability/distribution/bernoulli.py
  6. +5
    -5
      mindspore/nn/probability/distribution/exponential.py
  7. +3
    -3
      mindspore/nn/probability/distribution/geometric.py
  8. +4
    -4
      mindspore/nn/probability/distribution/normal.py
  9. +2
    -3
      mindspore/nn/probability/distribution/transformed_distribution.py
  10. +3
    -3
      mindspore/nn/probability/distribution/uniform.py

+ 3
- 3
mindspore/nn/probability/bijector/power_transform.py View File

@@ -17,7 +17,7 @@ from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from ..distribution._utils.utils import CheckTensor from ..distribution._utils.utils import CheckTensor
from ..distribution._utils.custom_ops import log_by_step, log1p_by_step, expm1_by_step
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step
from .bijector import Bijector from .bijector import Bijector


class PowerTransform(Bijector): class PowerTransform(Bijector):
@@ -59,10 +59,10 @@ class PowerTransform(Bijector):
validator.check_number("power", power, 0, Rel.GE, self.name) validator.check_number("power", power, 0, Rel.GE, self.name)
self._power = power self._power = power
self.pow = P.Pow() self.pow = P.Pow()
self.exp = P.Exp()
self.exp = exp_by_step
self.expm1 = expm1_by_step
self.log = log_by_step self.log = log_by_step
self.log1p = log1p_by_step self.log1p = log1p_by_step
self.expm1 = expm1_by_step


self.checktensor = CheckTensor() self.checktensor = CheckTensor()




+ 3
- 3
mindspore/nn/probability/bijector/softplus.py View File

@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
from mindspore.nn.layer.activation import LogSigmoid from mindspore.nn.layer.activation import LogSigmoid
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor, CheckTensor from ..distribution._utils.utils import cast_to_tensor, CheckTensor
from ..distribution._utils.custom_ops import log_by_step, expm1_by_step
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
from .bijector import Bijector from .bijector import Bijector


class Softplus(Bijector): class Softplus(Bijector):
@@ -59,10 +59,10 @@ class Softplus(Bijector):
super(Softplus, self).__init__(name=name, param=param) super(Softplus, self).__init__(name=name, param=param)
self._sharpness = cast_to_tensor(sharpness) self._sharpness = cast_to_tensor(sharpness)


self.abs = P.Abs()
self.exp = P.Exp()
self.exp = exp_by_step
self.log = log_by_step self.log = log_by_step
self.expm1 = expm1_by_step self.expm1 = expm1_by_step
self.abs = P.Abs()
self.fill = P.Fill() self.fill = P.Fill()
self.greater = P.Greater() self.greater = P.Greater()
self.less = P.Less() self.less = P.Less()


+ 2
- 1
mindspore/nn/probability/distribution/_utils/__init__.py View File

@@ -28,7 +28,8 @@ __all__ = [
'check_scalar_from_param', 'check_scalar_from_param',
'check_prob', 'check_prob',
'check_type', 'check_type',
'exp_by_step',
'expm1_by_step',
'log_by_step', 'log_by_step',
'log1p_by_step', 'log1p_by_step',
'expm1_by_step',
] ]

+ 33
- 8
mindspore/nn/probability/distribution/_utils/custom_ops.py View File

@@ -17,10 +17,36 @@ import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype


def exp_by_step(input_x):
"""
Log op on Ascend doesn't supprot int types.
Fix this with casting the type.
"""
exp = P.Exp()
cast = P.Cast()
dtype = P.DType()
checktype = P.IsSubClass()

if checktype(dtype(input_x), mstype.int_):
input_x = cast(input_x, mstype.float32)
elif checktype(dtype(input_x), mstype.float_):
pass
else:
return None
return exp(input_x)

def expm1_by_step(input_x):
"""
Expm1 ops under GPU context.
"""
return exp_by_step(input_x) - 1.0

def log_by_step(input_x): def log_by_step(input_x):
""" """
Log op on Ascend is calculated as log(abs(x)). Log op on Ascend is calculated as log(abs(x)).
Fix this with putting negative values as nan. Fix this with putting negative values as nan.
And log op on Ascend doesn't supprot int types.
Fix this with casting the type.
""" """
log = P.Log() log = P.Log()
less = P.Less() less = P.Less()
@@ -30,8 +56,14 @@ def log_by_step(input_x):
dtype = P.DType() dtype = P.DType()
shape = P.Shape() shape = P.Shape()
select = P.Select() select = P.Select()
checktype = P.IsSubClass()


input_x = cast(input_x, mstype.float32)
if checktype(dtype(input_x), mstype.int_):
input_x = cast(input_x, mstype.float32)
elif checktype(dtype(input_x), mstype.float_):
pass
else:
return None
nan = fill(dtype(input_x), shape(input_x), np.nan) nan = fill(dtype(input_x), shape(input_x), np.nan)
inf = fill(dtype(input_x), shape(input_x), np.inf) inf = fill(dtype(input_x), shape(input_x), np.inf)
neg_x = less(input_x, 0.0) neg_x = less(input_x, 0.0)
@@ -45,10 +77,3 @@ def log1p_by_step(x):
Log1p ops on GPU device or when device_target == GPU. Log1p ops on GPU device or when device_target == GPU.
""" """
return log_by_step(x + 1.0) return log_by_step(x + 1.0)

def expm1_by_step(input_x):
"""
Expm1 ops under GPU context.
"""
exp = P.Exp()
return exp(input_x) - 1.0

+ 3
- 3
mindspore/nn/probability/distribution/bernoulli.py View File

@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
from ._utils.custom_ops import log_by_step
from ._utils.custom_ops import exp_by_step, log_by_step


class Bernoulli(Distribution): class Bernoulli(Distribution):
""" """
@@ -108,15 +108,15 @@ class Bernoulli(Distribution):
self._probs = probs self._probs = probs


# ops needed for the class # ops needed for the class
self.exp = exp_by_step
self.log = log_by_step
self.squeeze = P.Squeeze(0) self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.dtypeop = P.DType() self.dtypeop = P.DType()
self.erf = P.Erf() self.erf = P.Erf()
self.exp = P.Exp()
self.floor = P.Floor() self.floor = P.Floor()
self.fill = P.Fill() self.fill = P.Fill()
self.log = log_by_step
self.less = P.Less() self.less = P.Less()
self.shape = P.Shape() self.shape = P.Shape()
self.select = P.Select() self.select = P.Select()


+ 5
- 5
mindspore/nn/probability/distribution/exponential.py View File

@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.custom_ops import log_by_step
from ._utils.custom_ops import exp_by_step, log_by_step


class Exponential(Distribution): class Exponential(Distribution):
""" """
@@ -112,14 +112,14 @@ class Exponential(Distribution):
self.minval = np.finfo(np.float).tiny self.minval = np.finfo(np.float).tiny


# ops needed for the class # ops needed for the class
self.exp = exp_by_step
self.log = log_by_step
self.squeeze = P.Squeeze(0) self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.dtypeop = P.DType() self.dtypeop = P.DType()
self.exp = P.Exp()
self.fill = P.Fill() self.fill = P.Fill()
self.less = P.Less() self.less = P.Less()
self.log = log_by_step
self.select = P.Select() self.select = P.Select()
self.shape = P.Shape() self.shape = P.Shape()
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
@@ -277,8 +277,8 @@ class Exponential(Distribution):
minval = self.const(self.minval) minval = self.const(self.minval)
maxval = self.const(1.0) maxval = self.const(1.0)
sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed) sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed)
sample = -self.log(sample_uniform) / rate
value = self.cast(sample, self.dtype)
sample = self.log(sample_uniform) / rate
value = self.cast(-sample, self.dtype)
if origin_shape == (): if origin_shape == ():
value = self.squeeze(value) value = self.squeeze(value)
return value return value

+ 3
- 3
mindspore/nn/probability/distribution/geometric.py View File

@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.custom_ops import log_by_step
from ._utils.custom_ops import exp_by_step, log_by_step


class Geometric(Distribution): class Geometric(Distribution):
""" """
@@ -113,16 +113,16 @@ class Geometric(Distribution):
self.minval = np.finfo(np.float).tiny self.minval = np.finfo(np.float).tiny


# ops needed for the class # ops needed for the class
self.exp = exp_by_step
self.log = log_by_step
self.squeeze = P.Squeeze(0) self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.dtypeop = P.DType() self.dtypeop = P.DType()
self.exp = P.Exp()
self.fill = P.Fill() self.fill = P.Fill()
self.floor = P.Floor() self.floor = P.Floor()
self.issubclass = P.IsSubClass() self.issubclass = P.IsSubClass()
self.less = P.Less() self.less = P.Less()
self.log = log_by_step
self.pow = P.Pow() self.pow = P.Pow()
self.select = P.Select() self.select = P.Select()
self.shape = P.Shape() self.shape = P.Shape()


+ 4
- 4
mindspore/nn/probability/distribution/normal.py View File

@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\ from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.custom_ops import log_by_step, expm1_by_step
from ._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step


class Normal(Distribution): class Normal(Distribution):
""" """
@@ -114,14 +114,14 @@ class Normal(Distribution):
self._sd_value = sd self._sd_value = sd


#ops needed for the class #ops needed for the class
self.exp = exp_by_step
self.expm1 = expm1_by_step
self.log = log_by_step
self.squeeze = P.Squeeze(0) self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.erf = P.Erf() self.erf = P.Erf()
self.exp = P.Exp()
self.expm1 = expm1_by_step
self.fill = P.Fill() self.fill = P.Fill()
self.log = log_by_step
self.shape = P.Shape() self.shape = P.Shape()
self.sq = P.Square() self.sq = P.Square()
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()


+ 2
- 3
mindspore/nn/probability/distribution/transformed_distribution.py View File

@@ -13,13 +13,12 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Transformed Distribution""" """Transformed Distribution"""
from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
import mindspore.nn as nn import mindspore.nn as nn
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import check_type, raise_not_impl_error from ._utils.utils import check_type, raise_not_impl_error
from ._utils.custom_ops import log_by_step
from ._utils.custom_ops import exp_by_step, log_by_step


class TransformedDistribution(Distribution): class TransformedDistribution(Distribution):
""" """
@@ -56,7 +55,7 @@ class TransformedDistribution(Distribution):
self._bijector = bijector self._bijector = bijector
self._distribution = distribution self._distribution = distribution
self._is_linear_transformation = bijector.is_constant_jacobian self._is_linear_transformation = bijector.is_constant_jacobian
self.exp = P.Exp()
self.exp = exp_by_step
self.log = log_by_step self.log = log_by_step


@property @property


+ 3
- 3
mindspore/nn/probability/distribution/uniform.py View File

@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\ from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.custom_ops import log_by_step
from ._utils.custom_ops import exp_by_step, log_by_step


class Uniform(Distribution): class Uniform(Distribution):
""" """
@@ -113,15 +113,15 @@ class Uniform(Distribution):
self._high = high self._high = high


# ops needed for the class # ops needed for the class
self.exp = exp_by_step
self.log = log_by_step
self.squeeze = P.Squeeze(0) self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.dtypeop = P.DType() self.dtypeop = P.DType()
self.exp = P.Exp()
self.fill = P.Fill() self.fill = P.Fill()
self.less = P.Less() self.less = P.Less()
self.lessequal = P.LessEqual() self.lessequal = P.LessEqual()
self.log = log_by_step
self.logicaland = P.LogicalAnd() self.logicaland = P.LogicalAnd()
self.select = P.Select() self.select = P.Select()
self.shape = P.Shape() self.shape = P.Shape()


Loading…
Cancel
Save