| @@ -386,38 +386,6 @@ def _regenerate_output_shape(x_shp, ind_shp, axis): | |||||
| @bprop_getters.register(P.Gather) | @bprop_getters.register(P.Gather) | ||||
| def get_bprop_gather(self): | |||||
| """Generate bprop for GatherV2""" | |||||
| def bprop(x, indices, axis, out, dout): | |||||
| orig_indices = indices | |||||
| if F.rank(dout) == 0: | |||||
| dout = P.ExpandDims()(dout, -1) | |||||
| if F.rank(indices) == 0: | |||||
| indices = P.ExpandDims()(indices, -1) | |||||
| x_shp = shape_op(x) | |||||
| ind_shp = shape_op(indices) | |||||
| out_shp = _regenerate_output_shape(x_shp, ind_shp, axis) | |||||
| dout = reshape(dout, out_shp) | |||||
| x_shp = shape_op(x) | |||||
| out_shp = shape_op(dout) | |||||
| ind_shp = shape_op(indices) | |||||
| # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) | |||||
| perm_1 = _generate_shape_index(out_shp, ind_shp, axis) | |||||
| values_transpose = transpose(dout, perm_1) | |||||
| if -1 in shape_op(x): | |||||
| params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis]) | |||||
| else: | |||||
| params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) | |||||
| # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) | |||||
| perm_2 = _generate_inverse_index(x_shp, axis) | |||||
| params_grad = transpose(params_grad, perm_2) | |||||
| return params_grad, zeros_like(orig_indices), zeros_like(axis) | |||||
| return bprop | |||||
| @bprop_getters.register(P.GatherV2) | @bprop_getters.register(P.GatherV2) | ||||
| def get_bprop_gather_v2(self): | def get_bprop_gather_v2(self): | ||||
| """Generate bprop for GatherV2""" | """Generate bprop for GatherV2""" | ||||
| @@ -601,6 +569,7 @@ def get_bprop_range(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(P.Pack) | |||||
| @bprop_getters.register(P.Stack) | @bprop_getters.register(P.Stack) | ||||
| def get_bprop_stack(self): | def get_bprop_stack(self): | ||||
| """Generate bprop for Stack""" | """Generate bprop for Stack""" | ||||