Browse Source

add bprop for ScatterMax

tags/v0.5.0-beta
yanzhenxiang2020 5 years ago
parent
commit
1edaf6dea7
1 changed files with 11 additions and 0 deletions
  1. +11
    -0
      mindspore/ops/_grad/grad_array_ops.py

+ 11
- 0
mindspore/ops/_grad/grad_array_ops.py View File

@@ -496,6 +496,17 @@ def get_bprop_tensor_scatter_update(self):
return bprop


@bprop_getters.register(P.ScatterMax)
def get_bprop_scatter_max(self):
"""Generate bprop for ScatterMax"""
gather = P.GatherV2()

def bprop(x, indices, update, out, dout):
return dout, zeros_like(indices), gather(dout, indices, 0)

return bprop


@bprop_getters.register(P.Argmax)
def get_bprop_argmax(self):
"""Generate bprop for Argmax"""


Loading…
Cancel
Save