|
|
|
@@ -25,7 +25,12 @@ from ...common import dtype as mstype |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul', 'Moments'] |
|
|
|
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'IGamma', 'MatMul', 'Moments'] |
|
|
|
|
|
|
|
|
|
|
|
@constexpr
def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
    """
    Check that the dtype of an input is one of the permitted dtypes.

    The four arguments are forwarded, unchanged and in order, to
    `validator.check_type_name`; nothing is returned.

    Args:
        param_name (str): Name of the input, used by the validator for reporting.
        input_dtype: The dtype to check.
        allow_dtypes: The permitted dtype(s). Callers in this file pass either a
            list of dtypes or a single dtype.
        cls_name (str): Name of the calling class, used by the validator for reporting.
    """
    # NOTE(review): presumably the validator raises when `input_dtype` is not
    # allowed -- confirm against `Validator.check_type_name`.
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
|
|
|
|
|
|
|
|
|
|
class ReduceLogSumExp(Cell): |
|
|
|
@@ -43,7 +48,7 @@ class ReduceLogSumExp(Cell): |
|
|
|
Default : False. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input_x** (Tensor[Number]) - The input tensor. With float16 or float32 data type. |
|
|
|
- **input_x** (Tensor) - The input tensor. With float16 or float32 data type. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, has the same dtype as the `input_x`. |
|
|
|
@@ -213,7 +218,7 @@ class LGamma(Cell): |
|
|
|
when x = +/- inf, return +inf |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input_x** (Tensor[Number]) - The input tensor. Only float16, float32 are supported. |
|
|
|
- **input_x** (Tensor) - The input tensor. Only float16, float32 are supported. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, has the same shape and dtype as the `input_x`. |
|
|
|
@@ -267,7 +272,7 @@ class LGamma(Cell): |
|
|
|
|
|
|
|
def construct(self, input_x): |
|
|
|
input_dtype = self.dtype(input_x) |
|
|
|
check_tensors_dtype_same(input_dtype, [mstype.float16, mstype.float32], "LGamma") |
|
|
|
_check_input_dtype("input", input_dtype, [mstype.float16, mstype.float32], self.cls_name) |
|
|
|
infinity = self.fill(input_dtype, self.shape(input_x), self.inf) |
|
|
|
|
|
|
|
need_to_reflect = self.less(input_x, 0.5) |
|
|
|
@@ -307,6 +312,260 @@ class LGamma(Cell): |
|
|
|
return self.select(self.isfinite(input_x), result, infinity) |
|
|
|
|
|
|
|
|
|
|
|
# Machine epsilon of float16 / float32 wrapped as tensors. The Igamma helpers
# below pick one of them, according to the input dtype, as their convergence
# threshold.
eps_fp16 = Tensor(np.finfo(np.float16).eps, mstype.float16)
eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
|
|
|
|
|
|
|
def _while_helper_func(cond, body, vals):
    """
    Apply `body` to `vals` repeatedly while any element of `cond(vals)` is true.

    Args:
        cond (Callable): Maps the loop state to an object exposing `.any()`
            (an element-wise boolean mask in this file).
        body (Callable): Maps the current loop state to the next loop state.
        vals: Initial loop state.

    Returns:
        The loop state after the last iteration, or `vals` itself when the
        condition is false on entry.
    """
    keep_going = cond(vals).any()
    while keep_going:
        vals = body(vals)
        # Re-evaluate on the updated state, exactly once per iteration.
        keep_going = cond(vals).any()
    return vals
|
|
|
|
|
|
|
|
|
|
|
def _IgammaSeries(ax, x, a, enabled):
    """
    Helper function for computing Igamma using a power series.

    Args:
        ax (Tensor): Prefactor multiplied into the summed series.
        x (Tensor): Upper integration limit of the incomplete Gamma function.
        a (Tensor): Shape parameter of the incomplete Gamma function.
        enabled (Tensor): Boolean mask; only elements where it is true are iterated.

    Returns:
        Tensor, `(series_sum * ax) / a`.
    """
    and_op = P.LogicalAnd()
    gt_op = P.Greater()
    fill_op = P.Fill()
    shape_op = P.Shape()
    dtype_op = P.DType()
    select_op = P.Select()

    # Convergence threshold: machine epsilon of the working dtype.
    if dtype_op(ax) == mstype.float16:
        epsilon = eps_fp16
    else:
        epsilon = eps_fp32

    def cond(vals):
        # The first slot of the loop state is the per-element "still iterating" mask.
        return vals[0]

    def body(vals):
        active, r_prev, c_prev, sum_prev, x_val, dc_da_prev, dans_da_prev = vals

        r_next = r_prev + 1
        ratio = x_val / r_next
        # dc_da / dans_da look like the derivative of the term / partial sum with
        # respect to `a`; they are carried in the state but not returned.
        dc_da_next = dc_da_prev * ratio + (-1 * c_prev * x_val) / (r_next * r_next)
        dans_da_next = dans_da_prev + dc_da_next
        c_next = c_prev * ratio
        sum_next = sum_prev + c_next
        # Keep iterating an element while its latest term is still significant
        # relative to the running sum.
        still_active = and_op(active, gt_op(c_next / sum_next, epsilon))

        # Elements that were already inactive keep their previous state.
        return (still_active,
                select_op(active, r_next, r_prev),
                select_op(active, c_next, c_prev),
                select_op(active, sum_next, sum_prev),
                select_op(active, x_val, x_val),
                select_op(active, dc_da_next, dc_da_prev),
                select_op(active, dans_da_next, dans_da_prev))

    a_dtype = dtype_op(a)
    a_shape = shape_op(a)
    ones = fill_op(a_dtype, a_shape, 1)
    zeros = fill_op(a_dtype, a_shape, 0)

    # State layout: (enabled, r, c, ans, x, dc_da, dans_da).
    final_state = _while_helper_func(cond, body, (enabled, a, ones, ones, x, zeros, zeros))
    series_sum = final_state[3]
    return (series_sum * ax) / a
|
|
|
|
|
|
|
|
|
|
|
def _IgammacContinuedFraction(ax, x, a, enabled):
    """
    Helper function for computing Igammac using a continued fraction.

    Successive convergents `pk / qk` of the continued fraction are generated
    until, for every enabled element, the relative change `t` between two
    consecutive convergents drops to machine epsilon or below, or the
    iteration counter reaches 2000.

    Args:
        ax (Tensor): Prefactor multiplied into the converged fraction.
        x (Tensor): Upper integration limit of the incomplete Gamma function.
        a (Tensor): Shape parameter of the incomplete Gamma function.
        enabled (Tensor): Boolean mask; only elements where it is true are iterated.

    Returns:
        Tensor, `ans * ax`, where `ans` is the last accepted convergent.
    """

    abs_x = P.Abs()
    logicaland = P.LogicalAnd()
    greater = P.Greater()
    less = P.Less()
    notequal = P.NotEqual()
    fill = P.Fill()
    shape = P.Shape()
    dtype = P.DType()
    select = P.Select()

    # Convergence / rescaling threshold: machine epsilon of the working dtype.
    if dtype(ax) == mstype.float16:
        epsilon = eps_fp16
    else:
        epsilon = eps_fp32

    def cond(vals):
        enabled = vals[0]
        c = vals[5]
        # Hard cap of 2000 iterations on top of the per-element convergence mask.
        return logicaland(less(c, 2000), enabled)

    def body(vals):
        enabled = vals[0]
        ans = vals[1]
        t = vals[2]
        y = vals[3]
        z = vals[4]
        c = vals[5]
        pkm1 = vals[6]
        qkm1 = vals[7]
        pkm2 = vals[8]
        qkm2 = vals[9]

        dpkm2_da = vals[10]
        dqkm2_da = vals[11]
        dpkm1_da = vals[12]
        dqkm1_da = vals[13]
        dans_da = vals[14]

        c = c + 1
        y = y + 1
        z = z + 2

        yc = y * c
        pk = pkm1 * z - pkm2 * yc
        qk = qkm1 * z - qkm2 * yc
        qk_is_nonzero = notequal(qk, 0)
        # Where qk == 0 this division is not meaningful; every use of `r`
        # below is masked by `qk_is_nonzero`.
        r = pk / qk

        # Relative change between the previous and the new convergent; forced to
        # 1 (i.e. "not converged") where the new convergent is undefined.
        t = select(qk_is_nonzero, abs_x((ans - r) / r), fill(dtype(t), shape(t), 1))
        ans = select(qk_is_nonzero, r, ans)

        # Derivative terms are still advanced so the loop-state layout is unchanged.
        dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c
        dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c
        dans_da_new = select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da)

        pkm2 = pkm1
        pkm1 = pk
        qkm2 = qkm1
        qkm1 = qk

        dpkm2_da = dpkm1_da
        dqkm2_da = dqkm1_da
        dpkm1_da = dpk_da
        dqkm1_da = dqk_da

        # Scale numerators and denominators down together once they grow past
        # 1 / epsilon; their ratio is unaffected.
        rescale = greater(abs_x(pk), 1 / epsilon)
        pkm2 = select(rescale, pkm2 * epsilon, pkm2)
        pkm1 = select(rescale, pkm1 * epsilon, pkm1)
        qkm2 = select(rescale, qkm2 * epsilon, qkm2)
        qkm1 = select(rescale, qkm1 * epsilon, qkm1)

        dpkm2_da = select(rescale, dpkm2_da * epsilon, dpkm2_da)
        dqkm2_da = select(rescale, dqkm2_da * epsilon, dqkm2_da)
        dpkm1_da = select(rescale, dpkm1_da * epsilon, dpkm1_da)
        dqkm1_da = select(rescale, dqkm1_da * epsilon, dqkm1_da)

        # Terminate on the convergence of the value itself (relative change `t`),
        # matching the value-based criterion used by `_IgammaSeries`. Previously
        # the change of the derivative estimate was tested and `t` was never read.
        conditional = logicaland(enabled, greater(t, epsilon))

        # `c` is advanced for every element (no select) because it only serves
        # as the iteration counter checked in `cond`.
        return (conditional, select(enabled, ans, vals[1]), select(enabled, t, vals[2]),
                select(enabled, y, vals[3]), select(enabled, z, vals[4]),
                c, select(enabled, pkm1, vals[6]),
                select(enabled, qkm1, vals[7]), select(enabled, pkm2, vals[8]),
                select(enabled, qkm2, vals[9]), select(enabled, dpkm2_da, vals[10]),
                select(enabled, dqkm2_da, vals[11]), select(enabled, dpkm1_da, vals[12]),
                select(enabled, dqkm1_da, vals[13]), select(enabled, dans_da_new, vals[14]))

    # Initial convergent and recurrence seeds.
    y = 1 - a
    z = x + y + 1
    c = fill(dtype(x), shape(x), 0)
    pkm2 = fill(dtype(x), shape(x), 1)
    qkm2 = x
    pkm1 = x + 1
    qkm1 = z * x
    ans = pkm1 / qkm1
    t = fill(dtype(x), shape(x), 1)
    dpkm2_da = fill(dtype(x), shape(x), 0)
    dqkm2_da = fill(dtype(x), shape(x), 0)
    dpkm1_da = fill(dtype(x), shape(x), 0)
    dqkm1_da = -x
    dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1
    # State layout: (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2,
    #                dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da).
    vals = (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)
    vals = _while_helper_func(cond, body, vals)
    ans = vals[1]
    return ans * ax
|
|
|
|
|
|
|
|
|
|
|
class IGamma(Cell):
    r"""
    Calculates the lower regularized incomplete Gamma function.

    The lower regularized incomplete Gamma function is defined as:

    .. math::
        P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)

    where

    .. math::
        gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt

    is the lower incomplete Gamma function.

    Above :math:`Q(a, x)` is the upper regularized incomplete Gamma function.

    Elements where `x` equals 0 yield 0. Elements where `x` or `a` is negative
    yield NaN.

    Inputs:
        - **a** (Tensor) - The input tensor. With float16 or float32 data type. `a` should have
          the same dtype with `x`.
        - **x** (Tensor) - The input tensor. With float16 or float32 data type. `x` should have
          the same dtype with `a`.

    Outputs:
        Tensor, has the same dtype as `a` and `x`.

    Examples:
        >>> input_a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> input_x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
        >>> igamma = nn.IGamma()
        >>> output = igamma(input_a, input_x)
        >>> print(output)
        [0.593994 0.35276785 0.21486944 0.13337152]
    """

    def __init__(self):
        super(IGamma, self).__init__()
        # Constants: log of the largest finite value of each supported dtype,
        # used in `construct` to detect underflow of the prefactor.
        self.log_maxfloat16 = Tensor(np.log(np.finfo(np.float16).max), mstype.float16)
        self.log_maxfloat32 = Tensor(np.log(np.finfo(np.float32).max), mstype.float32)

        # Logical operations.
        self.logicaland = P.LogicalAnd()
        self.logicalor = P.LogicalOr()
        self.logicalnot = P.LogicalNot()

        # Comparisons.
        self.equal = P.Equal()
        self.greater = P.Greater()
        self.less = P.Less()

        # Element-wise math.
        self.neg = P.Neg()
        self.log = P.Log()
        self.exp = P.Exp()
        self.lgamma = LGamma()

        # Tensor construction, selection and introspection.
        self.select = P.Select()
        self.zeroslike = P.ZerosLike()
        self.fill = P.Fill()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.const = P.ScalarToArray()
        self.cast = P.Cast()

    def construct(self, a, x):
        a_dtype = self.dtype(a)
        x_dtype = self.dtype(x)
        _check_input_dtype("input_a", a_dtype, [mstype.float16, mstype.float32], self.cls_name)
        _check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name)

        # Element-wise classification of the inputs.
        x_is_zero = self.equal(x, 0)
        domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
        # The continued fraction (complement) is used where x > 1 and x > a,
        # the power series everywhere else.
        use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))

        # Prefactor, first evaluated in log space so underflow can be detected
        # before exponentiating.
        log_ax = a * self.log(x) - x - self.lgamma(a)
        if a_dtype == mstype.float16:
            log_maxfloat = self.log_maxfloat16
        else:
            log_maxfloat = self.log_maxfloat32
        underflow = self.less(log_ax, self.neg(log_maxfloat))
        ax = self.exp(log_ax)

        # Elements with a trivial or invalid result are excluded from iteration.
        excluded = self.logicalor(self.logicalor(x_is_zero, domain_error), underflow)
        enabled = self.logicalnot(excluded)

        complement = _IgammacContinuedFraction(ax, x, a, self.logicaland(enabled, use_igammac))
        series = _IgammaSeries(ax, x, a, self.logicaland(enabled, self.logicalnot(use_igammac)))
        output = self.select(use_igammac, 1 - complement, series)

        # Special values; the domain-error mask is applied last and so takes precedence.
        output = self.select(x_is_zero, self.zeroslike(output), output)
        nan_values = self.fill(self.dtype(a), self.shape(a), np.nan)
        output = self.select(domain_error, nan_values, output)
        return output
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def get_broadcast_matmul_shape(x_shape, y_shape): |
|
|
|
"""get broadcast_matmul shape""" |
|
|
|
@@ -453,11 +712,6 @@ class MatMul(Cell): |
|
|
|
return matmul_broadcast |
|
|
|
|
|
|
|
|
|
|
|
@constexpr
def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
    """
    Check that the dtype of an input is one of the permitted dtypes.

    The four arguments are forwarded, unchanged and in order, to
    `validator.check_type_name`; nothing is returned.

    Args:
        param_name (str): Name of the input, used by the validator for reporting.
        input_dtype: The dtype to check.
        allow_dtypes: The permitted dtype(s).
        cls_name (str): Name of the calling class, used by the validator for reporting.
    """
    # NOTE(review): presumably the validator raises when `input_dtype` is not
    # allowed -- confirm against `Validator.check_type_name`.
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
|
|
|
|
|
|
|
|
|
|
class Moments(Cell): |
|
|
|
""" |
|
|
|
Calculate the mean and variance of `x`. |
|
|
|
|