Browse Source

!9207 Fix maxgradgrad minimumgradgrad relugradgrad

From: @yuan_shen_zhou
Reviewed-by: @liangchenghui
Signed-off-by: @liangchenghui
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f42adabb57
1 changed files with 14 additions and 3 deletions
  1. +14
    -3
      mindspore/ops/_grad/grad_implementations.py

+ 14
- 3
mindspore/ops/_grad/grad_implementations.py View File

@@ -14,29 +14,40 @@
# ============================================================================

"""bprop primitives"""
from ..operations import _grad_ops as G
from .. import functional as F
from .. import operations as P
from ..composite import multitype_ops as C
from .grad_base import bprops

get_dtype = P.DType()
# Unused parameters are placeholders.


@bprops.register("MaximumGrad")
def bprop_maximum_grad_grad(x, y, z, out, dout):
"""Backpropagator for primitive `MaximumGrad`."""
return F.zeros_like(x), F.zeros_like(y), F.zeros_like(z)
out0 = F.cast(out[0] != 0, get_dtype(dout[0]))
out1 = F.cast(out[1] != 0, get_dtype(dout[1]))
dz = out0 * dout[0] + out1 * dout[1]
return F.zeros_like(x), F.zeros_like(y), dz


@bprops.register("MinimumGrad")
def bprop_minimum_grad_grad(x, y, z, out, dout):
"""Backpropagator for primitive `MinimumGrad`."""
return F.zeros_like(x), F.zeros_like(y), F.zeros_like(z)
out0 = F.cast(out[0] != 0, get_dtype(dout[0]))
out1 = F.cast(out[1] != 0, get_dtype(dout[1]))
dz = out0 * dout[0] + out1 * dout[1]
return F.zeros_like(x), F.zeros_like(y), dz


@bprops.register("ReluGrad")
def bprop_relu_grad_grad(x, y, out, dout):
"""Backpropagator for primitive `ReluGrad`."""
return F.zeros_like(x), F.zeros_like(y)
input_grad = G.ReluGrad()
dy = input_grad(dout, y)
return dy, F.zeros_like(y)


@bprops.register("scalar_add")


Loading…
Cancel
Save