diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 59a2f27d22..3b8ce7c1e2 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -377,6 +377,39 @@ def _regenerate_output_shape(x_shp, ind_shp, axis): @bprop_getters.register(P.Gather) +def get_bprop_gather(self): + """Generate bprop for GatherV2""" + + def bprop(x, indices, axis, out, dout): + orig_indices = indices + if F.rank(dout) == 0: + dout = P.ExpandDims()(dout, -1) + if F.rank(indices) == 0: + indices = P.ExpandDims()(indices, -1) + x_shp = shape_op(x) + ind_shp = shape_op(indices) + out_shp = _regenerate_output_shape(x_shp, ind_shp, axis) + dout = reshape(dout, out_shp) + + x_shp = shape_op(x) + out_shp = shape_op(dout) + ind_shp = shape_op(indices) + # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) + perm_1 = _generate_shape_index(out_shp, ind_shp, axis) + values_transpose = transpose(dout, perm_1) + if -1 in shape_op(x): + params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis]) + else: + params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) + # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) + perm_2 = _generate_inverse_index(x_shp, axis) + params_grad = transpose(params_grad, perm_2) + return params_grad, zeros_like(orig_indices), zeros_like(axis) + + return bprop + + +@bprop_getters.register(P.GatherV2) def get_bprop_gather_v2(self): """Generate bprop for GatherV2""" diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index 715f30c6f2..9ac29d2f80 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -156,6 +156,16 @@ def bprop_batchmatmul(self): @bprop_getters.register(P.Add) +def get_bprop_add(self): + """Grad definition for `Add` operation.""" + + def bprop(x, y, out, dout): + return binop_grad_common(x, y, dout, dout) + + return bprop + + +@bprop_getters.register(P.TensorAdd) def get_bprop_tensor_add(self): """Grad definition for `Add` operation."""