|
|
|
@@ -368,8 +368,8 @@ def get_bprop_sparse_gather_v2(self): |
|
|
|
x_tail_shp = x_shp[1:] |
|
|
|
values_shape = indices_size + x_tail_shp |
|
|
|
values = reshape(dout, values_shape) |
|
|
|
indices = reshape(indices, indices_size) |
|
|
|
return RowTensor(indices, values, x_shp), zeros_like(indices), zeros_like(axis) |
|
|
|
indices_new = reshape(indices, indices_size) |
|
|
|
return RowTensor(indices_new, values, x_shp), zeros_like(indices), zeros_like(axis) |
|
|
|
if F.rank(dout) == 0: |
|
|
|
dout = P.ExpandDims()(dout, -1) |
|
|
|
if F.rank(indices) == 0: |
|
|
|
|