Browse Source

!4004 add squreasumall grad

Merge pull request !4004 from fangzehua/squaresumall
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
c493859978
2 changed files with 18 additions and 1 deletions
  1. +16
    -0
      mindspore/ops/_grad/grad_math_ops.py
  2. +2
    -1
      tests/ut/python/ops/test_ops.py

+ 16
- 0
mindspore/ops/_grad/grad_math_ops.py View File

@@ -397,6 +397,22 @@ def get_bprop_xlogy(self):


return bprop return bprop


@bprop_getters.register(P.SquareSumAll)
def get_bprop_square_sum_all(self):
"""Grad definition for `Square` operation."""
mul_func = P.Mul()
fill_func = P.Fill()
dtype = P.DType()

def bprop(x, y, out, dout):
temp_x = mul_func(dout[0], x)
temp_y = mul_func(dout[1], y)
dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)
dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
return (dx, dy)

return bprop



@bprop_getters.register(P.Sqrt) @bprop_getters.register(P.Sqrt)
def get_bprop_sqrt(self): def get_bprop_sqrt(self):


+ 2
- 1
tests/ut/python/ops/test_ops.py View File

@@ -1188,7 +1188,8 @@ test_case_math_ops = [
'block': P.SquareSumAll(), 'block': P.SquareSumAll(),
'desc_inputs': [Tensor(np.array([0, 1, 4, 5]).astype(np.float32)), 'desc_inputs': [Tensor(np.array([0, 1, 4, 5]).astype(np.float32)),
Tensor(np.array([1, 1, 3, 7]).astype(np.float32))], Tensor(np.array([1, 1, 3, 7]).astype(np.float32))],
'skip': ['backward']}),
'desc_bprop': [Tensor(np.array(0.1).astype(np.float32)),
Tensor(np.array(0.1).astype(np.float32))]}),
('Cos', { ('Cos', {
'block': P.Cos(), 'block': P.Cos(),
'desc_inputs': [[2, 3]], 'desc_inputs': [[2, 3]],


Loading…
Cancel
Save