|
|
|
@@ -15,7 +15,6 @@ |
|
|
|
"""Softplus Bijector""" |
|
|
|
import numpy as np |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from mindspore.common import dtype as mstype |
|
|
|
from mindspore.nn.layer.activation import LogSigmoid |
|
|
|
from mindspore._checkparam import Validator as validator |
|
|
|
from ..distribution._utils.utils import cast_to_tensor |
|
|
|
@@ -71,6 +70,7 @@ class Softplus(Bijector): |
|
|
|
self.log = log_generic |
|
|
|
self.expm1 = expm1_generic |
|
|
|
self.abs = P.Abs() |
|
|
|
self.dtypeop = P.DType() |
|
|
|
self.fill = P.Fill() |
|
|
|
self.greater = P.Greater() |
|
|
|
self.less = P.Less() |
|
|
|
@@ -90,7 +90,7 @@ class Softplus(Bijector): |
|
|
|
too_large = self.greater(x, -self.threshold) |
|
|
|
too_small_value = self.exp(x) |
|
|
|
too_large_value = x |
|
|
|
ones = self.fill(mstype.float32, self.shape(x), 1.0) |
|
|
|
ones = self.fill(self.dtypeop(x), self.shape(x), 1.0) |
|
|
|
too_small_or_too_large = self.logicalor(too_small, too_large) |
|
|
|
x = self.select(too_small_or_too_large, ones, x) |
|
|
|
y = self.log(self.exp(x) + 1.0) |
|
|
|
@@ -106,7 +106,7 @@ class Softplus(Bijector): |
|
|
|
too_large = self.greater(x, -self.threshold) |
|
|
|
too_small_value = self.log(x) |
|
|
|
too_large_value = x |
|
|
|
ones = self.fill(mstype.float32, self.shape(x), 1.0) |
|
|
|
ones = self.fill(self.dtypeop(x), self.shape(x), 1.0) |
|
|
|
too_small_or_too_large = self.logicalor(too_small, too_large) |
|
|
|
x = self.select(too_small_or_too_large, ones, x) |
|
|
|
y = x + self.log(self.abs(self.expm1(-x))) |
|
|
|
|