Browse Source

!1986 fixed validator for CumSum

Merge pull request !1986 from jiangjinsheng/issue_fix2
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
af85b2cebf
3 changed files with 6 additions and 10 deletions
  1. +4
    -3
      mindspore/ops/_grad/grad_math_ops.py
  2. +2
    -4
      mindspore/ops/operations/math_ops.py
  3. +0
    -3
      mindspore/ops/operations/nn_ops.py

+ 4
- 3
mindspore/ops/_grad/grad_math_ops.py View File

@@ -1001,15 +1001,16 @@ def get_bprop_bessel_i1e(self):
reciprocal = P.Reciprocal()
cast = P.Cast()
dtype = P.DType()
abs_ops = P.Abs()

def bprop(x, out, dout):
zeros = zeros_like(x)
np_eps = const_utils.get_np_eps(dtype(x))
eps = cast(np_eps, dtype(x))
x_is_valid = less(eps, x)
x_is_valid = less(eps, abs_ops(x))
x_safe = select(x_is_valid, x, eps + zeros)
tmp = bessel_i0e(x_safe) - out * (sign(x) + reciprocal(x_safe))
dx = select(x_is_valid, tmp, 0.5 + zeros)
tmp = bessel_i0e(x_safe) - out * (sign(x_safe) + reciprocal(x_safe))
dx = select(x_is_valid, tmp, cast(0.5, dtype(x)) + zeros) * dout
return (dx,)
return bprop



+ 2
- 4
mindspore/ops/operations/math_ops.py View File

@@ -672,6 +672,8 @@ class CumSum(PrimitiveWithInfer):
def __infer__(self, x, axis):
cls_name = self.name
x_shp = x['shape']
if axis['value'] is None:
raise ValueError(f"For {self.name}, axis must be const.")
validator.check_value_type('axis', axis['value'], [int], cls_name)
valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name)
@@ -679,10 +681,6 @@ class CumSum(PrimitiveWithInfer):
'dtype': x['dtype'],
'value': None}

def infer_value(self, x, axis):
if axis is None:
raise ValueError(f"For {self.name}, axis must be const.")


class AddN(PrimitiveWithInfer):
"""


+ 0
- 3
mindspore/ops/operations/nn_ops.py View File

@@ -1767,9 +1767,6 @@ class ApplyRMSProp(PrimitiveWithInfer):
def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon):
if decay is None or momentum is None or epsilon is None:
raise ValueError(f"For {self.name}, decay, momentum, epsilon must be const.")
if not self.is_ge and self.is_d:
return None, None, None
return None


class ApplyCenteredRMSProp(PrimitiveWithInfer):


Loading…
Cancel
Save