Browse Source

Fix errors in log calculation logics

tags/v0.7.0-beta
peixu_ren 5 years ago
parent
commit
1c8eb9b15d
2 changed files with 108 additions and 26 deletions
  1. +10
    -4
      mindspore/nn/probability/distribution/_utils/custom_ops.py
  2. +98
    -22
      tests/ut/python/nn/distribution/test_bernoulli.py

+ 10
- 4
mindspore/nn/probability/distribution/_utils/custom_ops.py View File

@@ -15,24 +15,30 @@
"""Utitly functions to help distribution class.""" """Utitly functions to help distribution class."""
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common import dtype as mstype


def log_by_step(input_x): def log_by_step(input_x):
""" """
Log op on Ascend is calculated as log(abs(x)). Log op on Ascend is calculated as log(abs(x)).
Fix this with putting negative values as nan. Fix this with putting negative values as nan.
""" """
select = P.Select()
log = P.Log() log = P.Log()
less = P.Less()
lessequal = P.LessEqual() lessequal = P.LessEqual()
fill = P.Fill() fill = P.Fill()
cast = P.Cast()
dtype = P.DType() dtype = P.DType()
shape = P.Shape() shape = P.Shape()
select = P.Select()


input_x = cast(input_x, mstype.float32)
nan = fill(dtype(input_x), shape(input_x), np.nan)
inf = fill(dtype(input_x), shape(input_x), np.inf)
neg_x = less(input_x, 0.0)
nonpos_x = lessequal(input_x, 0.0) nonpos_x = lessequal(input_x, 0.0)
log_x = log(input_x) log_x = log(input_x)
nan = fill(dtype(input_x), shape(input_x), np.nan)
result = select(nonpos_x, nan, log_x)
return result
result = select(nonpos_x, -inf, log_x)
return select(neg_x, nan, result)


def log1p_by_step(x): def log1p_by_step(x):
""" """


+ 98
- 22
tests/ut/python/nn/distribution/test_bernoulli.py View File

@@ -157,51 +157,127 @@ def test_cross_entropy():
ans = net(probs_b, probs_a) ans = net(probs_b, probs_a)
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)


class BernoulliBasics(nn.Cell):
class BernoulliConstruct(nn.Cell):
"""
Bernoulli distribution: going through construct.
"""
def __init__(self):
super(BernoulliConstruct, self).__init__()
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
self.b1 = msd.Bernoulli(dtype=dtype.int32)

def construct(self, value, probs):
prob = self.b('prob', value)
prob1 = self.b('prob', value, probs)
prob2 = self.b1('prob', value, probs)
return prob + prob1 + prob2

def test_bernoulli_construct():
"""
Test probability function going through construct.
"""
net = BernoulliConstruct()
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5], dtype=dtype.float32)
ans = net(value, probs)
assert isinstance(ans, Tensor)

class BernoulliMean(nn.Cell):
""" """
Test class: basic mean/sd/var/mode/entropy function. Test class: basic mean/sd/var/mode/entropy function.
""" """
def __init__(self): def __init__(self):
super(BernoulliBasics, self).__init__()
super(BernoulliMean, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)


def construct(self): def construct(self):
mean = self.b.mean() mean = self.b.mean()
return mean

def test_mean():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
"""
net = BernoulliMean()
ans = net()
assert isinstance(ans, Tensor)

class BernoulliSd(nn.Cell):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def __init__(self):
super(BernoulliSd, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

def construct(self):
sd = self.b.sd() sd = self.b.sd()
return sd

def test_sd():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
"""
net = BernoulliSd()
ans = net()
assert isinstance(ans, Tensor)

class BernoulliVar(nn.Cell):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def __init__(self):
super(BernoulliVar, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

def construct(self):
var = self.b.var() var = self.b.var()
return var

def test_var():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
"""
net = BernoulliVar()
ans = net()
assert isinstance(ans, Tensor)

class BernoulliMode(nn.Cell):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def __init__(self):
super(BernoulliMode, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

def construct(self):
mode = self.b.mode() mode = self.b.mode()
entropy = self.b.entropy()
return mean + sd + var + mode + entropy
return mode


def test_bascis():
def test_mode():
""" """
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
""" """
net = BernoulliBasics()
net = BernoulliMode()
ans = net() ans = net()
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)


class BernoulliConstruct(nn.Cell):
class BernoulliEntropy(nn.Cell):
""" """
Bernoulli distribution: going through construct.
Test class: basic mean/sd/var/mode/entropy function.
""" """
def __init__(self): def __init__(self):
super(BernoulliConstruct, self).__init__()
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
self.b1 = msd.Bernoulli(dtype=dtype.int32)
super(BernoulliEntropy, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)


def construct(self, value, probs):
prob = self.b('prob', value)
prob1 = self.b('prob', value, probs)
prob2 = self.b1('prob', value, probs)
return prob + prob1 + prob2
def construct(self):
entropy = self.b.entropy()
return entropy


def test_bernoulli_construct():
def test_entropy():
""" """
Test probability function going through construct.
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
""" """
net = BernoulliConstruct()
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5], dtype=dtype.float32)
ans = net(value, probs)
net = BernoulliEntropy()
ans = net()
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)

Loading…
Cancel
Save