Merge pull request !1986 from jiangjinsheng/issue_fix2tags/v0.5.0-beta
| @@ -1001,15 +1001,16 @@ def get_bprop_bessel_i1e(self): | |||||
| reciprocal = P.Reciprocal() | reciprocal = P.Reciprocal() | ||||
| cast = P.Cast() | cast = P.Cast() | ||||
| dtype = P.DType() | dtype = P.DType() | ||||
| abs_ops = P.Abs() | |||||
| def bprop(x, out, dout): | def bprop(x, out, dout): | ||||
| zeros = zeros_like(x) | zeros = zeros_like(x) | ||||
| np_eps = const_utils.get_np_eps(dtype(x)) | np_eps = const_utils.get_np_eps(dtype(x)) | ||||
| eps = cast(np_eps, dtype(x)) | eps = cast(np_eps, dtype(x)) | ||||
| x_is_valid = less(eps, x) | |||||
| x_is_valid = less(eps, abs_ops(x)) | |||||
| x_safe = select(x_is_valid, x, eps + zeros) | x_safe = select(x_is_valid, x, eps + zeros) | ||||
| tmp = bessel_i0e(x_safe) - out * (sign(x) + reciprocal(x_safe)) | |||||
| dx = select(x_is_valid, tmp, 0.5 + zeros) | |||||
| tmp = bessel_i0e(x_safe) - out * (sign(x_safe) + reciprocal(x_safe)) | |||||
| dx = select(x_is_valid, tmp, cast(0.5, dtype(x)) + zeros) * dout | |||||
| return (dx,) | return (dx,) | ||||
| return bprop | return bprop | ||||
| @@ -672,6 +672,8 @@ class CumSum(PrimitiveWithInfer): | |||||
| def __infer__(self, x, axis): | def __infer__(self, x, axis): | ||||
| cls_name = self.name | cls_name = self.name | ||||
| x_shp = x['shape'] | x_shp = x['shape'] | ||||
| if axis['value'] is None: | |||||
| raise ValueError(f"For {self.name}, axis must be const.") | |||||
| validator.check_value_type('axis', axis['value'], [int], cls_name) | validator.check_value_type('axis', axis['value'], [int], cls_name) | ||||
| valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] | valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] | ||||
| validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name) | validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name) | ||||
| @@ -679,10 +681,6 @@ class CumSum(PrimitiveWithInfer): | |||||
| 'dtype': x['dtype'], | 'dtype': x['dtype'], | ||||
| 'value': None} | 'value': None} | ||||
| def infer_value(self, x, axis): | |||||
| if axis is None: | |||||
| raise ValueError(f"For {self.name}, axis must be const.") | |||||
| class AddN(PrimitiveWithInfer): | class AddN(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -1767,9 +1767,6 @@ class ApplyRMSProp(PrimitiveWithInfer): | |||||
| def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon): | def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon): | ||||
| if decay is None or momentum is None or epsilon is None: | if decay is None or momentum is None or epsilon is None: | ||||
| raise ValueError(f"For {self.name}, decay, momentum, epsilon must be const.") | raise ValueError(f"For {self.name}, decay, momentum, epsilon must be const.") | ||||
| if not self.is_ge and self.is_d: | |||||
| return None, None, None | |||||
| return None | |||||
| class ApplyCenteredRMSProp(PrimitiveWithInfer): | class ApplyCenteredRMSProp(PrimitiveWithInfer): | ||||