Browse Source

!8234 Remove expm1_generic and log1p_generic from PP utils

From: @peixu_ren
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
7f725b93a3
9 changed files with 17 additions and 25 deletions
  1. +3
    -3
      mindspore/nn/probability/bijector/power_transform.py
  2. +2
    -2
      mindspore/nn/probability/bijector/softplus.py
  3. +0
    -2
      mindspore/nn/probability/distribution/_utils/__init__.py
  4. +0
    -7
      mindspore/nn/probability/distribution/_utils/custom_ops.py
  5. +2
    -2
      mindspore/nn/probability/distribution/gumbel.py
  6. +2
    -2
      mindspore/nn/probability/distribution/log_normal.py
  7. +3
    -3
      mindspore/nn/probability/distribution/logistic.py
  8. +2
    -2
      mindspore/nn/probability/distribution/normal.py
  9. +3
    -2
      tests/ut/python/nn/probability/distribution/test_lognormal.py

+ 3
- 3
mindspore/nn/probability/bijector/power_transform.py View File

@@ -15,7 +15,7 @@
"""Power Bijector"""
from mindspore.ops import operations as P
from ..distribution._utils.utils import check_greater_equal_zero
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
from ..distribution._utils.custom_ops import exp_generic, log_generic
from .bijector import Bijector


@@ -73,9 +73,9 @@ class PowerTransform(Bijector):
self.dtypeop = P.DType()
self.cast = P.Cast()
self.exp = exp_generic
self.expm1 = expm1_generic
self.expm1 = P.Expm1()
self.log = log_generic
self.log1p = log1p_generic
self.log1p = P.Log1p()

@property
def power(self):


+ 2
- 2
mindspore/nn/probability/bijector/softplus.py View File

@@ -16,7 +16,7 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.nn.layer.activation import LogSigmoid
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic
from ..distribution._utils.custom_ops import exp_generic, log_generic
from .bijector import Bijector


@@ -65,7 +65,7 @@ class Softplus(Bijector):

self.exp = exp_generic
self.log = log_generic
self.expm1 = expm1_generic
self.expm1 = P.Expm1()
self.abs = P.Abs()
self.dtypeop = P.DType()
self.cast = P.Cast()


+ 0
- 2
mindspore/nn/probability/distribution/_utils/__init__.py View File

@@ -25,9 +25,7 @@ __all__ = [
'check_greater_zero',
'check_prob',
'exp_generic',
'expm1_generic',
'log_generic',
'log1p_generic',
'broadcast_to',
'set_param_type',
'CheckTensor',


+ 0
- 7
mindspore/nn/probability/distribution/_utils/custom_ops.py View File

@@ -32,13 +32,6 @@ def exp_generic(input_x):
return exp(input_x)


def expm1_generic(input_x):
"""
Expm1 ops under GPU context.
"""
return exp_generic(input_x) - 1.0


def log_generic(input_x):
"""
Log op on Ascend is calculated as log(abs(x)).


+ 2
- 2
mindspore/nn/probability/distribution/gumbel.py View File

@@ -22,7 +22,7 @@ import mindspore.nn.probability.bijector as msb
import mindspore.nn.probability.distribution as msd
from .transformed_distribution import TransformedDistribution
from ._utils.utils import check_distribution_name
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic
from ._utils.custom_ops import exp_generic, log_generic

class Gumbel(TransformedDistribution):
"""
@@ -120,7 +120,7 @@ class Gumbel(TransformedDistribution):
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.exp = exp_generic
self.expm1 = expm1_generic
self.expm1 = P.Expm1()
self.fill = P.Fill()
self.lgamma = nn.LGamma()
self.log = log_generic


+ 2
- 2
mindspore/nn/probability/distribution/log_normal.py View File

@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn.probability.bijector as msb
import mindspore.nn.probability.distribution as msd
from ._utils.utils import check_distribution_name
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic
from ._utils.custom_ops import exp_generic, log_generic

class LogNormal(msd.TransformedDistribution):
"""
@@ -146,7 +146,7 @@ class LogNormal(msd.TransformedDistribution):

#ops needed for the class
self.exp = exp_generic
self.expm1 = expm1_generic
self.expm1 = P.Expm1()
self.log = log_generic
self.const = P.ScalarToArray()
self.erf = P.Erf()


+ 3
- 3
mindspore/nn/probability/distribution/logistic.py View File

@@ -20,7 +20,7 @@ from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater_zero
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
from ._utils.custom_ops import exp_generic, log_generic


class Logistic(Distribution):
@@ -124,11 +124,11 @@ class Logistic(Distribution):
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.exp = exp_generic
self.expm1 = expm1_generic
self.expm1 = P.Expm1()
self.fill = P.Fill()
self.less = P.Less()
self.log = log_generic
self.log1p = log1p_generic
self.log1p = P.Log1p()
self.logicalor = P.LogicalOr()
self.erf = P.Erf()
self.greater = P.Greater()


+ 2
- 2
mindspore/nn/probability/distribution/normal.py View File

@@ -20,7 +20,7 @@ from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater_zero, check_distribution_name
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic
from ._utils.custom_ops import exp_generic, log_generic


class Normal(Distribution):
@@ -137,7 +137,7 @@ class Normal(Distribution):

# ops needed for the class
self.exp = exp_generic
self.expm1 = expm1_generic
self.expm1 = P.Expm1()
self.log = log_generic
self.erf = P.Erf()
self.squeeze = P.Squeeze(0)


+ 3
- 2
tests/ut/python/nn/probability/distribution/test_lognormal.py View File

@@ -175,15 +175,16 @@ class LogNormalBasics(nn.Cell):

def construct(self):
mean = self.n.mean()
sd = self.n.sd()
mode = self.n.mode()
entropy = self.n.entropy()
return mean + sd + mode + entropy
return mean + mode + entropy

def test_bascis():
"""
Test mean/sd/mode/entropy functionality of LogNormal.
"""
from mindspore import context
context.set_context(device_target="Ascend")
net = LogNormalBasics()
ans = net()
assert isinstance(ans, Tensor)


Loading…
Cancel
Save