Browse Source

Modify the supported dtype of IGamma to align with TensorFlow

tags/v1.1.0
peixu_ren 5 years ago
parent
commit
f51fde14d6
1 changed files with 11 additions and 11 deletions
  1. +11
    -11
      mindspore/nn/layer/math.py

+ 11
- 11
mindspore/nn/layer/math.py View File

@@ -428,7 +428,7 @@ class DiGamma(Cell):
nan, real_result) nan, real_result)




eps_fp16 = Tensor(np.finfo(np.float16).eps, mstype.float16)
eps_fp64 = Tensor(np.finfo(np.float64).eps, mstype.float64)
eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32) eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)


def _while_helper_func(cond, body, vals): def _while_helper_func(cond, body, vals):
@@ -447,8 +447,8 @@ def _IgammaSeries(ax, x, a, enabled):
dtype = P.DType() dtype = P.DType()
select = P.Select() select = P.Select()


if dtype(ax) == mstype.float16:
epsilon = eps_fp16
if dtype(ax) == mstype.float64:
epsilon = eps_fp64
else: else:
epsilon = eps_fp32 epsilon = eps_fp32


@@ -499,8 +499,8 @@ def _IgammacContinuedFraction(ax, x, a, enabled):
dtype = P.DType() dtype = P.DType()
select = P.Select() select = P.Select()


if dtype(ax) == mstype.float16:
epsilon = eps_fp16
if dtype(ax) == mstype.float64:
epsilon = eps_fp64
else: else:
epsilon = eps_fp32 epsilon = eps_fp32


@@ -619,9 +619,9 @@ class IGamma(Cell):
``Ascend`` ``Ascend``


Inputs: Inputs:
- **a** (Tensor) - The input tensor. With float16 or float32 data type. `a` should have
- **a** (Tensor) - The input tensor. With float32 or float64 data type. `a` should have
the same dtype with `x`. the same dtype with `x`.
- **x** (Tensor) - The input tensor. With float16 or float32 data type. `x` should have
- **x** (Tensor) - The input tensor. With float32 or float64 data type. `x` should have
the same dtype with `a`. the same dtype with `a`.


Outputs: Outputs:
@@ -639,7 +639,7 @@ class IGamma(Cell):
def __init__(self): def __init__(self):
super(IGamma, self).__init__() super(IGamma, self).__init__()
# const numbers # const numbers
self.log_maxfloat16 = Tensor(np.log(np.finfo(np.float16).max), mstype.float16)
self.log_maxfloat64 = Tensor(np.log(np.finfo(np.float64).max), mstype.float64)
self.log_maxfloat32 = Tensor(np.log(np.finfo(np.float32).max), mstype.float32) self.log_maxfloat32 = Tensor(np.log(np.finfo(np.float32).max), mstype.float32)


# operations # operations
@@ -664,7 +664,7 @@ class IGamma(Cell):
def construct(self, a, x): def construct(self, a, x):
a_dtype = self.dtype(a) a_dtype = self.dtype(a)
x_dtype = self.dtype(x) x_dtype = self.dtype(x)
_check_input_dtype("input_a", a_dtype, [mstype.float16, mstype.float32], self.cls_name)
_check_input_dtype("input_a", a_dtype, [mstype.float32, mstype.float64], self.cls_name)
_check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name) _check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name)
domain_error = self.logicalor(self.less(x, 0), self.less(a, 0)) domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a)) use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
@@ -673,8 +673,8 @@ class IGamma(Cell):
a = boradcastto(a) a = boradcastto(a)
x = boradcastto(x) x = boradcastto(x)
x_is_zero = self.equal(x, 0) x_is_zero = self.equal(x, 0)
if a_dtype == mstype.float16:
log_maxfloat = self.log_maxfloat16
if a_dtype == mstype.float64:
log_maxfloat = self.log_maxfloat64
else: else:
log_maxfloat = self.log_maxfloat32 log_maxfloat = self.log_maxfloat32
underflow = self.less(ax, self.neg(log_maxfloat)) underflow = self.less(ax, self.neg(log_maxfloat))


Loading…
Cancel
Save