Browse Source

add pack bprop

tags/v1.2.0-rc1
jinyaohui 4 years ago
parent
commit
eb97093f8b
1 changed files with 1 additions and 32 deletions
  1. +1
    -32
      mindspore/ops/_grad/grad_array_ops.py

+ 1
- 32
mindspore/ops/_grad/grad_array_ops.py View File

@@ -386,38 +386,6 @@ def _regenerate_output_shape(x_shp, ind_shp, axis):


@bprop_getters.register(P.Gather)
def get_bprop_gather(self):
"""Generate bprop for GatherV2"""

def bprop(x, indices, axis, out, dout):
orig_indices = indices
if F.rank(dout) == 0:
dout = P.ExpandDims()(dout, -1)
if F.rank(indices) == 0:
indices = P.ExpandDims()(indices, -1)
x_shp = shape_op(x)
ind_shp = shape_op(indices)
out_shp = _regenerate_output_shape(x_shp, ind_shp, axis)
dout = reshape(dout, out_shp)

x_shp = shape_op(x)
out_shp = shape_op(dout)
ind_shp = shape_op(indices)
# Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
values_transpose = transpose(dout, perm_1)
if -1 in shape_op(x):
params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis])
else:
params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
perm_2 = _generate_inverse_index(x_shp, axis)
params_grad = transpose(params_grad, perm_2)
return params_grad, zeros_like(orig_indices), zeros_like(axis)

return bprop


@bprop_getters.register(P.GatherV2)
def get_bprop_gather_v2(self):
"""Generate bprop for GatherV2"""
@@ -601,6 +569,7 @@ def get_bprop_range(self):
return bprop


@bprop_getters.register(P.Pack)
@bprop_getters.register(P.Stack)
def get_bprop_stack(self):
"""Generate bprop for Stack"""


Loading…
Cancel
Save