|
|
|
@@ -433,7 +433,6 @@ class DiGamma(Cell): |
|
|
|
nan, real_result) |
|
|
|
|
|
|
|
|
|
|
|
eps_fp64 = Tensor(np.finfo(np.float64).eps, mstype.float64) |
|
|
|
eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32) |
|
|
|
|
|
|
|
def _while_helper_func(cond, body, vals): |
|
|
|
@@ -452,10 +451,8 @@ def _IgammaSeries(ax, x, a, enabled): |
|
|
|
dtype = P.DType() |
|
|
|
select = P.Select() |
|
|
|
|
|
|
|
if dtype(ax) == mstype.float64: |
|
|
|
epsilon = eps_fp64 |
|
|
|
else: |
|
|
|
epsilon = eps_fp32 |
|
|
|
# If more data types are supported, this epsilon need to be selected. |
|
|
|
epsilon = eps_fp32 |
|
|
|
|
|
|
|
def cond(vals): |
|
|
|
enabled = vals[0] |
|
|
|
@@ -504,10 +501,8 @@ def _IgammacContinuedFraction(ax, x, a, enabled): |
|
|
|
dtype = P.DType() |
|
|
|
select = P.Select() |
|
|
|
|
|
|
|
if dtype(ax) == mstype.float64: |
|
|
|
epsilon = eps_fp64 |
|
|
|
else: |
|
|
|
epsilon = eps_fp32 |
|
|
|
# If more data types are supported, this epsilon need to be selected. |
|
|
|
epsilon = eps_fp32 |
|
|
|
|
|
|
|
def cond(vals): |
|
|
|
enabled = vals[0] |
|
|
|
@@ -624,9 +619,9 @@ class IGamma(Cell): |
|
|
|
``Ascend`` |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **a** (Tensor) - The input tensor. With float32 or float64 data type. `a` should have |
|
|
|
- **a** (Tensor) - The input tensor. With float32 data type. `a` should have |
|
|
|
the same dtype with `x`. |
|
|
|
- **x** (Tensor) - The input tensor. With float32 or float64 data type. `x` should have |
|
|
|
- **x** (Tensor) - The input tensor. With float32 data type. `x` should have |
|
|
|
the same dtype with `a`. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
@@ -644,7 +639,7 @@ class IGamma(Cell): |
|
|
|
def __init__(self): |
|
|
|
super(IGamma, self).__init__() |
|
|
|
# const numbers |
|
|
|
self.log_maxfloat64 = Tensor(np.log(np.finfo(np.float64).max), mstype.float64) |
|
|
|
# If more data types are supported, this float max value need to be selected. |
|
|
|
self.log_maxfloat32 = Tensor(np.log(np.finfo(np.float32).max), mstype.float32) |
|
|
|
|
|
|
|
# operations |
|
|
|
@@ -669,7 +664,7 @@ class IGamma(Cell): |
|
|
|
def construct(self, a, x): |
|
|
|
a_dtype = self.dtype(a) |
|
|
|
x_dtype = self.dtype(x) |
|
|
|
_check_input_dtype("input_a", a_dtype, [mstype.float32, mstype.float64], self.cls_name) |
|
|
|
_check_input_dtype("input_a", a_dtype, [mstype.float32], self.cls_name) |
|
|
|
_check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name) |
|
|
|
domain_error = self.logicalor(self.less(x, 0), self.less(a, 0)) |
|
|
|
use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a)) |
|
|
|
@@ -680,10 +675,7 @@ class IGamma(Cell): |
|
|
|
x = boradcastto(x) |
|
|
|
y = boradcastto(y) |
|
|
|
x_is_zero = self.equal(x, 0) |
|
|
|
if a_dtype == mstype.float64: |
|
|
|
log_maxfloat = self.log_maxfloat64 |
|
|
|
else: |
|
|
|
log_maxfloat = self.log_maxfloat32 |
|
|
|
log_maxfloat = self.log_maxfloat32 |
|
|
|
underflow = self.less(ax, self.neg(log_maxfloat)) |
|
|
|
ax = self.exp(ax) |
|
|
|
enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow)) |
|
|
|
|