|
|
|
@@ -386,38 +386,6 @@ def _regenerate_output_shape(x_shp, ind_shp, axis): |
|
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(P.Gather) |
|
|
|
def get_bprop_gather(self): |
|
|
|
"""Generate bprop for GatherV2""" |
|
|
|
|
|
|
|
def bprop(x, indices, axis, out, dout): |
|
|
|
orig_indices = indices |
|
|
|
if F.rank(dout) == 0: |
|
|
|
dout = P.ExpandDims()(dout, -1) |
|
|
|
if F.rank(indices) == 0: |
|
|
|
indices = P.ExpandDims()(indices, -1) |
|
|
|
x_shp = shape_op(x) |
|
|
|
ind_shp = shape_op(indices) |
|
|
|
out_shp = _regenerate_output_shape(x_shp, ind_shp, axis) |
|
|
|
dout = reshape(dout, out_shp) |
|
|
|
|
|
|
|
x_shp = shape_op(x) |
|
|
|
out_shp = shape_op(dout) |
|
|
|
ind_shp = shape_op(indices) |
|
|
|
# Example: out_shape:(3,2,3) axis 1 -> (1,0,2) |
|
|
|
perm_1 = _generate_shape_index(out_shp, ind_shp, axis) |
|
|
|
values_transpose = transpose(dout, perm_1) |
|
|
|
if -1 in shape_op(x): |
|
|
|
params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis]) |
|
|
|
else: |
|
|
|
params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) |
|
|
|
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0) |
|
|
|
perm_2 = _generate_inverse_index(x_shp, axis) |
|
|
|
params_grad = transpose(params_grad, perm_2) |
|
|
|
return params_grad, zeros_like(orig_indices), zeros_like(axis) |
|
|
|
|
|
|
|
return bprop |
|
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(P.GatherV2) |
|
|
|
def get_bprop_gather_v2(self): |
|
|
|
"""Generate bprop for GatherV2""" |
|
|
|
@@ -601,6 +569,7 @@ def get_bprop_range(self): |
|
|
|
return bprop |
|
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(P.Pack) |
|
|
|
@bprop_getters.register(P.Stack) |
|
|
|
def get_bprop_stack(self): |
|
|
|
"""Generate bprop for Stack""" |
|
|
|
|