|
|
@@ -496,6 +496,17 @@ def get_bprop_tensor_scatter_update(self): |
|
|
return bprop |
|
|
return bprop |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(P.ScatterMax) |
|
|
|
|
|
def get_bprop_scatter_max(self): |
|
|
|
|
|
"""Generate bprop for ScatterMax""" |
|
|
|
|
|
gather = P.GatherV2() |
|
|
|
|
|
|
|
|
|
|
|
def bprop(x, indices, update, out, dout): |
|
|
|
|
|
return dout, zeros_like(indices), gather(dout, indices, 0) |
|
|
|
|
|
|
|
|
|
|
|
return bprop |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(P.Argmax) |
|
|
@bprop_getters.register(P.Argmax) |
|
|
def get_bprop_argmax(self): |
|
|
def get_bprop_argmax(self): |
|
|
"""Generate bprop for Argmax""" |
|
|
"""Generate bprop for Argmax""" |
|
|
|