|
|
|
@@ -425,11 +425,22 @@ def get_bprop_rsqrt(self): |
|
|
|
@bprop_getters.register(P.Reciprocal) |
|
|
|
def get_bprop_reciprocal(self): |
|
|
|
"""Grad definition for `Reciprocal` operation.""" |
|
|
|
reciprocal_grad = G.ReciprocalGrad() |
|
|
|
if self.target == "GPU": |
|
|
|
neg = P.Neg() |
|
|
|
mul = P.Mul() |
|
|
|
square = P.Square() |
|
|
|
reciprocal = P.Reciprocal() |
|
|
|
|
|
|
|
def bprop(x, out, dout): |
|
|
|
g = neg(reciprocal(square(x))) |
|
|
|
dx = mul(dout, g) |
|
|
|
return (dx,) |
|
|
|
else: |
|
|
|
reciprocal_grad = G.ReciprocalGrad() |
|
|
|
|
|
|
|
def bprop(x, out, dout): |
|
|
|
dx = reciprocal_grad(out, dout) |
|
|
|
return (dx,) |
|
|
|
def bprop(x, out, dout): |
|
|
|
dx = reciprocal_grad(out, dout) |
|
|
|
return (dx,) |
|
|
|
|
|
|
|
return bprop |
|
|
|
|
|
|
|
|