|
|
|
@@ -14,29 +14,40 @@ |
|
|
|
# ============================================================================ |
|
|
|
|
|
|
|
"""bprop primitives""" |
|
|
|
from ..operations import _grad_ops as G |
|
|
|
from .. import functional as F |
|
|
|
from .. import operations as P |
|
|
|
from ..composite import multitype_ops as C |
|
|
|
from .grad_base import bprops |
|
|
|
|
|
|
|
get_dtype = P.DType() |
|
|
|
# Unused parameters are placeholders. |
|
|
|
|
|
|
|
|
|
|
|
@bprops.register("MaximumGrad") |
|
|
|
def bprop_maximum_grad_grad(x, y, z, out, dout): |
|
|
|
"""Backpropagator for primitive `MaximumGrad`.""" |
|
|
|
return F.zeros_like(x), F.zeros_like(y), F.zeros_like(z) |
|
|
|
out0 = F.cast(out[0] != 0, get_dtype(dout[0])) |
|
|
|
out1 = F.cast(out[1] != 0, get_dtype(dout[1])) |
|
|
|
dz = out0 * dout[0] + out1 * dout[1] |
|
|
|
return F.zeros_like(x), F.zeros_like(y), dz |
|
|
|
|
|
|
|
|
|
|
|
@bprops.register("MinimumGrad") |
|
|
|
def bprop_minimum_grad_grad(x, y, z, out, dout): |
|
|
|
"""Backpropagator for primitive `MinimumGrad`.""" |
|
|
|
return F.zeros_like(x), F.zeros_like(y), F.zeros_like(z) |
|
|
|
out0 = F.cast(out[0] != 0, get_dtype(dout[0])) |
|
|
|
out1 = F.cast(out[1] != 0, get_dtype(dout[1])) |
|
|
|
dz = out0 * dout[0] + out1 * dout[1] |
|
|
|
return F.zeros_like(x), F.zeros_like(y), dz |
|
|
|
|
|
|
|
|
|
|
|
@bprops.register("ReluGrad") |
|
|
|
def bprop_relu_grad_grad(x, y, out, dout): |
|
|
|
"""Backpropagator for primitive `ReluGrad`.""" |
|
|
|
return F.zeros_like(x), F.zeros_like(y) |
|
|
|
input_grad = G.ReluGrad() |
|
|
|
dy = input_grad(dout, y) |
|
|
|
return dy, F.zeros_like(y) |
|
|
|
|
|
|
|
|
|
|
|
@bprops.register("scalar_add") |
|
|
|
|