From 1edaf6dea7ff2ef9ccef966519a9d04fd7695eab Mon Sep 17 00:00:00 2001 From: yanzhenxiang2020 Date: Thu, 18 Jun 2020 11:04:56 +0800 Subject: [PATCH] add bprop for ScatterMax --- mindspore/ops/_grad/grad_array_ops.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 4bc2dc5a6b..f6a2e8bb29 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -496,6 +496,17 @@ def get_bprop_tensor_scatter_update(self): return bprop +@bprop_getters.register(P.ScatterMax) +def get_bprop_scatter_max(self): + """Generate bprop for ScatterMax""" + gather = P.GatherV2() + + def bprop(x, indices, update, out, dout): + return dout, zeros_like(indices), gather(dout, indices, 0) + + return bprop + + @bprop_getters.register(P.Argmax) def get_bprop_argmax(self): """Generate bprop for Argmax"""