|
|
|
@@ -25,7 +25,7 @@ from ...common import dtype as mstype |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'DiGamma', 'IGamma', 'MatMul', 'Moments'] |
|
|
|
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'DiGamma', 'IGamma', 'LBeta', 'MatMul', 'Moments'] |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
@@ -227,6 +227,9 @@ class LGamma(Cell): |
|
|
|
when x is an integer less than or equal to 0, return +inf
|
|
|
when x = +/- inf, return +inf |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` ``GPU`` |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input_x** (Tensor) - The input tensor. Only float16 and float32 are supported.
|
|
|
|
|
|
|
@@ -346,6 +349,9 @@ class DiGamma(Cell): |
|
|
|
|
|
|
|
digamma(x) = digamma(1 - x) - pi * cot(pi * x) |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` ``GPU`` |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input_x** (Tensor[Number]) - The input tensor. Only float16 and float32 are supported.
|
|
|
|
|
|
|
@@ -609,6 +615,9 @@ class IGamma(Cell): |
|
|
|
|
|
|
|
Above :math:`Q(a, x)` is the upper regularized complete Gamma function. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **a** (Tensor) - The input tensor, with float16 or float32 data type. `a` should have



the same dtype as `x`.
|
|
|
@@ -679,6 +688,102 @@ class IGamma(Cell): |
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
class LBeta(Cell): |
|
|
|
r""" |
|
|
|
This is semantically equal to lgamma(x) + lgamma(y) - lgamma(x + y), i.e. the natural logarithm of the Beta function.
|
|
|
|
|
|
|
The method is more accurate for arguments above 8. The reason for accuracy loss in the naive computation



is catastrophic cancellation between the lgammas. This method avoids that cancellation by explicitly



decomposing lgamma into the Stirling approximation and an explicit log_gamma_correction, and cancelling



the large terms from the Stirling approximation analytically.
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` ``GPU`` |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - The input tensor, with float16 or float32 data type. `x` should have



the same dtype as `y`.



- **y** (Tensor) - The input tensor, with float16 or float32 data type. `y` should have



the same dtype as `x`.
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, has the same dtype as `x` and `y`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> input_x = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32)) |
|
|
|
>>> input_y = Tensor(np.array([2.0, 3.0, 14.0, 15.0]).astype(np.float32)) |
|
|
|
>>> lbeta = nn.LBeta() |
|
|
|
>>> output = lbeta(input_x, input_y)



>>> print(output)
|
|
|
[-1.7917596 -4.094345 -12.000229 -14.754799] |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(LBeta, self).__init__() |
|
|
|
# const numbers |
|
|
|
self.log_2pi = np.log(2 * np.pi) |
|
|
|
self.minimax_coeff = [-0.165322962780713e-02, |
|
|
|
0.837308034031215e-03, |
|
|
|
-0.595202931351870e-03, |
|
|
|
0.793650666825390e-03, |
|
|
|
-0.277777777760991e-02, |
|
|
|
0.833333333333333e-01] |
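# The coefficients above are a minimax fit (highest order first) to the Stirling-series
# correction term; the trailing values approach the classical Stirling coefficients 1/12 and -1/360.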
|
|
|
|
|
|
|
# operations |
|
|
|
self.log = P.Log() |
|
|
|
self.log1p = P.Log1p() |
|
|
|
self.less = P.Less() |
|
|
|
self.select = P.Select() |
|
|
|
self.shape = P.Shape() |
|
|
|
self.dtype = P.DType() |
|
|
|
self.lgamma = LGamma() |
|
|
|
|
|
|
|
def construct(self, x, y): |
|
|
|
x_dtype = self.dtype(x) |
|
|
|
y_dtype = self.dtype(y) |
|
|
|
_check_input_dtype("input_x", x_dtype, [mstype.float16, mstype.float32], self.cls_name) |
|
|
|
_check_input_dtype("input_y", y_dtype, x_dtype, self.cls_name) |
|
|
|
x_plus_y = x + y |
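# Broadcast both inputs to the shape of x + y so the elementwise comparisons and selects below line up.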
|
|
|
broadcastto = P.BroadcastTo(self.shape(x_plus_y))



x = broadcastto(x)



y = broadcastto(y)
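# Order the arguments so that x_min <= y_max elementwise; lbeta is symmetric in its
# arguments, so this loses nothing and lets each branch below assume the ordering.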
|
|
|
comp_less = self.less(x, y) |
|
|
|
x_min = self.select(comp_less, x, y) |
|
|
|
y_max = self.select(comp_less, y, x) |
|
|
|
|
|
|
|
# Minimax-polynomial evaluation of the Stirling-series correction term. It operates on
# tensors at run time, so it is a plain nested function rather than a @constexpr helper.
|
|
|
def _log_gamma_correction(x, minimax_coeff): |
|
|
|
inverse_x = 1. / x |
|
|
|
inverse_x_squared = inverse_x * inverse_x |
|
|
|
accum = minimax_coeff[0] |
|
|
|
for i in range(1, 6): |
|
|
|
accum = accum * inverse_x_squared + minimax_coeff[i] |
|
|
|
return accum * inverse_x |
|
|
|
|
|
|
|
log_gamma_correction_x = _log_gamma_correction(x_min, self.minimax_coeff) |
|
|
|
log_gamma_correction_y = _log_gamma_correction(y_max, self.minimax_coeff) |
|
|
|
log_gamma_correction_x_y = _log_gamma_correction(x_plus_y, self.minimax_coeff) |
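# The Stirling corrections above are only accurate for arguments >= 8; the branch
# selection at the end ensures they are consumed only in that regime.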
|
|
|
|
|
|
|
# Two large arguments case: y >= x >= 8. |
|
|
|
log_beta_two_large = 0.5 * self.log_2pi - 0.5 * self.log(y_max) \ |
|
|
|
+ log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \ |
|
|
|
+ (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max) |
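# Stable evaluation of lgamma(y_max) - lgamma(x_min + y_max) for y_max >= 8, with the
# large Stirling terms cancelled analytically: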
|
|
|
|
|
|
|
cancelled_stirling = -1 * (x_min + y_max - 0.5) * self.log1p(x_min / y_max) - x_min * self.log(y_max) + x_min |
|
|
|
correction = log_gamma_correction_y - log_gamma_correction_x_y |
|
|
|
log_gamma_difference_big_y = correction + cancelled_stirling |
|
|
|
|
|
|
|
# One large argument case: x < 8, y >= 8. |
|
|
|
log_beta_one_large = self.lgamma(x_min) + log_gamma_difference_big_y |
|
|
|
|
|
|
|
# Small arguments case: x <= y < 8. |
|
|
|
log_beta_small = self.lgamma(x_min) + self.lgamma(y_max) - self.lgamma(x_min + y_max) |
|
|
|
comp_xless8 = self.less(x_min, 8) |
|
|
|
comp_yless8 = self.less(y_max, 8) |
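# Per element: y_max < 8 selects the small case, x_min < 8 <= y_max the one-large case,
# and 8 <= x_min <= y_max the two-large case.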
|
|
|
temp = self.select(comp_yless8, log_beta_small, log_beta_one_large) |
|
|
|
return self.select(comp_xless8, temp, log_beta_two_large) |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def get_broadcast_matmul_shape(x_shape, y_shape): |
|
|
|
"""get broadcast_matmul shape""" |
|
|
|
|