|
|
|
@@ -397,6 +397,22 @@ def get_bprop_xlogy(self): |
|
|
|
|
|
|
|
return bprop |
|
|
|
|
|
|
|
@bprop_getters.register(P.SquareSumAll) |
|
|
|
def get_bprop_square_sum_all(self): |
|
|
|
"""Grad definition for `Square` operation.""" |
|
|
|
mul_func = P.Mul() |
|
|
|
fill_func = P.Fill() |
|
|
|
dtype = P.DType() |
|
|
|
|
|
|
|
def bprop(x, y, out, dout): |
|
|
|
temp_x = mul_func(dout[0], x) |
|
|
|
temp_y = mul_func(dout[1], y) |
|
|
|
dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x) |
|
|
|
dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y) |
|
|
|
return (dx, dy) |
|
|
|
|
|
|
|
return bprop |
|
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(P.Sqrt) |
|
|
|
def get_bprop_sqrt(self): |
|
|
|
|