|
|
|
@@ -433,7 +433,10 @@ def get_bprop_sparse_gather_v2(self): |
|
|
|
x_shp = shape_op(x) |
|
|
|
if axis == 0: |
|
|
|
indices_size = (size_op(indices),) |
|
|
|
x_tail_shp = x_shp[1:] |
|
|
|
if len(x_shp) <= 1: |
|
|
|
x_tail_shp = () |
|
|
|
else: |
|
|
|
x_tail_shp = x_shp[1:] |
|
|
|
values_shape = indices_size + x_tail_shp |
|
|
|
values = reshape(dout, values_shape) |
|
|
|
indices_new = reshape(indices, indices_size) |
|
|
|
|