From 50855e8fcf375b0d303c344d20cc22d298ded67f Mon Sep 17 00:00:00 2001 From: liangchenghui Date: Sat, 21 Nov 2020 16:11:34 +0800 Subject: [PATCH] Fix GatherV2 operator bprop shape problem. --- mindspore/ops/_grad/grad_array_ops.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index c08813aa7c..85fe9940d3 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -351,15 +351,30 @@ def _generate_inverse_index(x_shape, axis): return perm +@constexpr +def _regenerate_output_shape(x_shp, ind_shp, axis): + rank = len(x_shp) + if axis < 0: + axis += rank + out_shape = x_shp[:axis] + ind_shp + x_shp[axis + 1:] + return out_shape + + @bprop_getters.register(P.GatherV2) def get_bprop_gather_v2(self): """Generate bprop for GatherV2""" def bprop(x, indices, axis, out, dout): + orig_indices = indices if F.rank(dout) == 0: dout = P.ExpandDims()(dout, -1) if F.rank(indices) == 0: indices = P.ExpandDims()(indices, -1) + x_shp = shape_op(x) + ind_shp = shape_op(indices) + out_shp = _regenerate_output_shape(x_shp, ind_shp, axis) + dout = reshape(dout, out_shp) + x_shp = shape_op(x) out_shp = shape_op(dout) ind_shp = shape_op(indices) @@ -373,7 +388,7 @@ def get_bprop_gather_v2(self): # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) perm_2 = _generate_inverse_index(x_shp, axis) params_grad = transpose(params_grad, perm_2) - return params_grad, zeros_like(indices), zeros_like(axis) + return params_grad, zeros_like(orig_indices), zeros_like(axis) return bprop