```diff
@@ -428,7 +428,7 @@ class DiGamma(Cell):
                        nan, real_result)

-eps_fp16 = Tensor(np.finfo(np.float16).eps, mstype.float16)
+eps_fp64 = Tensor(np.finfo(np.float64).eps, mstype.float64)
 eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)

 def _while_helper_func(cond, body, vals):
```
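For context, these module-level constants are the dtype-dependent convergence tolerances used by the loops below. A NumPy-only sketch of the values the patched `Tensor` wrappers carry:

```python
# NumPy-only illustration of the machine-epsilon constants in the patch;
# the Tensor wrappers above hold exactly these values.
import numpy as np

eps_fp64 = np.finfo(np.float64).eps   # ~2.220446049250313e-16
eps_fp32 = np.finfo(np.float32).eps   # ~1.1920929e-07

# The series/continued-fraction loops below iterate until their running
# term falls beneath the epsilon of the input dtype.
print(eps_fp64, eps_fp32)
```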
```diff
@@ -447,8 +447,8 @@ def _IgammaSeries(ax, x, a, enabled):
     dtype = P.DType()
     select = P.Select()

-    if dtype(ax) == mstype.float16:
-        epsilon = eps_fp16
+    if dtype(ax) == mstype.float64:
+        epsilon = eps_fp64
     else:
         epsilon = eps_fp32
```
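What `_IgammaSeries` computes, in scalar form, is the regularized lower incomplete gamma via its power series, stopping once a term drops below the dtype's epsilon. A hedged pure-Python sketch (the helper name and scalar structure are mine; the real op works elementwise over Tensors):

```python
import math

def igamma_series(a: float, x: float, eps: float = 2.22e-16) -> float:
    # P(a, x) = x**a * exp(-x) / Gamma(a) * sum_{n>=0} x**n / (a*(a+1)*...*(a+n))
    ax = math.exp(a * math.log(x) - x - math.lgamma(a))
    term = total = 1.0 / a
    denom = a
    while term > total * eps:          # the role epsilon plays in the diff
        denom += 1.0
        term *= x / denom
        total += term
    return ax * total

# Spot check against the closed form P(1, x) = 1 - exp(-x).
print(igamma_series(1.0, 0.5), 1.0 - math.exp(-0.5))
```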
```diff
@@ -499,8 +499,8 @@ def _IgammacContinuedFraction(ax, x, a, enabled):
     dtype = P.DType()
     select = P.Select()

-    if dtype(ax) == mstype.float16:
-        epsilon = eps_fp16
+    if dtype(ax) == mstype.float64:
+        epsilon = eps_fp64
     else:
         epsilon = eps_fp32
```
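Its counterpart `_IgammacContinuedFraction` evaluates the upper tail Q(a, x) as a continued fraction. A scalar sketch using the modified Lentz scheme (the helper name and iteration cap are my assumptions), again terminating on the dtype epsilon:

```python
import math

def igammac_cf(a: float, x: float, eps: float = 2.22e-16) -> float:
    # Q(a, x) = x**a * exp(-x) / Gamma(a) * continued fraction (Lentz's method)
    ax = math.exp(a * math.log(x) - x - math.lgamma(a))
    tiny = 1e-300                      # guard against division by zero
    b = x + 1.0 - a
    c = 1.0 / tiny
    d = 1.0 / b
    h = d
    for n in range(1, 200):            # iteration cap is an assumption
        an = -n * (n - a)
        b += 2.0
        d = an * d + b
        if abs(d) < tiny:
            d = tiny
        c = b + an / c
        if abs(c) < tiny:
            c = tiny
        d = 1.0 / d
        delta = d * c
        h *= delta
        if abs(delta - 1.0) < eps:     # converged to the dtype tolerance
            break
    return ax * h

# Spot check against the closed form Q(1, x) = exp(-x).
print(igammac_cf(1.0, 2.0), math.exp(-2.0))
```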
```diff
@@ -619,9 +619,9 @@ class IGamma(Cell):
         ``Ascend``

     Inputs:
-        - **a** (Tensor) - The input tensor. With float16 or float32 data type. `a` should have
+        - **a** (Tensor) - The input tensor. With float32 or float64 data type. `a` should have
           the same dtype with `x`.
-        - **x** (Tensor) - The input tensor. With float16 or float32 data type. `x` should have
+        - **x** (Tensor) - The input tensor. With float32 or float64 data type. `x` should have
           the same dtype with `a`.

     Outputs:
```
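A usage sketch for the widened dtype support, assuming a MindSpore build that exports the cell as `nn.IGamma` (check the namespace in your version):

```python
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor

igamma = nn.IGamma()
a = Tensor(np.array([2.0, 4.0, 6.0]), mstype.float64)   # float64 now accepted
x = Tensor(np.array([2.0, 3.0, 4.0]), mstype.float64)   # float16 no longer is
print(igamma(a, x))   # elementwise regularized lower incomplete gamma P(a, x)
```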
```diff
@@ -639,7 +639,7 @@ class IGamma(Cell):
     def __init__(self):
         super(IGamma, self).__init__()
         # const numbers
-        self.log_maxfloat16 = Tensor(np.log(np.finfo(np.float16).max), mstype.float16)
+        self.log_maxfloat64 = Tensor(np.log(np.finfo(np.float64).max), mstype.float64)
         self.log_maxfloat32 = Tensor(np.log(np.finfo(np.float32).max), mstype.float32)

         # operations
```
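The `log_maxfloat*` constants are underflow guards: `exp(v)` is representable in a dtype only while `v` stays above `-log(finfo.max)`-ish magnitudes, and `construct` (below) compares `ax` against `-log_maxfloat` to detect when `exp(ax)` would vanish. NumPy-only values:

```python
import numpy as np

log_max64 = np.log(np.finfo(np.float64).max)   # ~709.78
log_max32 = np.log(np.finfo(np.float32).max)   # ~88.72

print(np.exp(log_max64 - 1.0))   # large but still finite in float64
print(np.exp(-800.0))            # 0.0: exp underflows far below -log_max64
```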
```diff
@@ -664,7 +664,7 @@ class IGamma(Cell):
     def construct(self, a, x):
         a_dtype = self.dtype(a)
         x_dtype = self.dtype(x)
-        _check_input_dtype("input_a", a_dtype, [mstype.float16, mstype.float32], self.cls_name)
+        _check_input_dtype("input_a", a_dtype, [mstype.float32, mstype.float64], self.cls_name)
         _check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name)
         domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
         use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
```
```diff
@@ -673,8 +673,8 @@ class IGamma(Cell):
         a = boradcastto(a)
         x = boradcastto(x)
         x_is_zero = self.equal(x, 0)
-        if a_dtype == mstype.float16:
-            log_maxfloat = self.log_maxfloat16
+        if a_dtype == mstype.float64:
+            log_maxfloat = self.log_maxfloat64
         else:
             log_maxfloat = self.log_maxfloat32
         underflow = self.less(ax, self.neg(log_maxfloat))
```
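Putting the pieces together, a scalar reading of the dispatch in `construct`, reusing `igamma_series` and `igammac_cf` from the sketches above. The real cell does all of this elementwise with `Select`, and the underflow fallback values here are my reading of that select chain, not copied from the source:

```python
import math
import sys

LOG_MAX = math.log(sys.float_info.max)   # float64 stand-in for log_maxfloat64

def igamma(a: float, x: float) -> float:
    if x < 0.0 or a < 0.0:
        return math.nan                   # domain_error branch
    if x == 0.0:
        return 0.0                        # x_is_zero branch
    use_igammac = x > 1.0 and x > a       # upper tail converges faster there
    log_ax = a * math.log(x) - x - math.lgamma(a)
    if log_ax < -LOG_MAX:                 # the underflow check in the diff:
        return 1.0 if use_igammac else 0.0  # Q ~ 0 so P ~ 1; or P ~ 0 directly
    if use_igammac:
        return 1.0 - igammac_cf(a, x)     # P = 1 - Q via continued fraction
    return igamma_series(a, x)            # P directly via the power series

print(igamma(2.0, 3.0))   # ~0.8009 (= 1 - 4*exp(-3))
```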