Merge pull request !4004 from fangzehua/squaresumalltags/v0.7.0-beta
| @@ -397,6 +397,22 @@ def get_bprop_xlogy(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(P.SquareSumAll) | |||||
| def get_bprop_square_sum_all(self): | |||||
| """Grad definition for `Square` operation.""" | |||||
| mul_func = P.Mul() | |||||
| fill_func = P.Fill() | |||||
| dtype = P.DType() | |||||
| def bprop(x, y, out, dout): | |||||
| temp_x = mul_func(dout[0], x) | |||||
| temp_y = mul_func(dout[1], y) | |||||
| dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x) | |||||
| dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y) | |||||
| return (dx, dy) | |||||
| return bprop | |||||
| @bprop_getters.register(P.Sqrt) | @bprop_getters.register(P.Sqrt) | ||||
| def get_bprop_sqrt(self): | def get_bprop_sqrt(self): | ||||
| @@ -1188,7 +1188,8 @@ test_case_math_ops = [ | |||||
| 'block': P.SquareSumAll(), | 'block': P.SquareSumAll(), | ||||
| 'desc_inputs': [Tensor(np.array([0, 1, 4, 5]).astype(np.float32)), | 'desc_inputs': [Tensor(np.array([0, 1, 4, 5]).astype(np.float32)), | ||||
| Tensor(np.array([1, 1, 3, 7]).astype(np.float32))], | Tensor(np.array([1, 1, 3, 7]).astype(np.float32))], | ||||
| 'skip': ['backward']}), | |||||
| 'desc_bprop': [Tensor(np.array(0.1).astype(np.float32)), | |||||
| Tensor(np.array(0.1).astype(np.float32))]}), | |||||
| ('Cos', { | ('Cos', { | ||||
| 'block': P.Cos(), | 'block': P.Cos(), | ||||
| 'desc_inputs': [[2, 3]], | 'desc_inputs': [[2, 3]], | ||||