|
|
|
@@ -230,8 +230,9 @@ def get_bprop_embedding_look_up(self): |
|
|
|
# Reshape the 'new_indices' |
|
|
|
new_indices_shape_changed = (size_op(new_indices),) |
|
|
|
new_indices = reshape_op(new_indices, new_indices_shape_changed) |
|
|
|
x_shp_tail = x_shp[1:] |
|
|
|
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail |
|
|
|
actual_dout_shape_changed = new_indices_shape_changed |
|
|
|
if len(x_shp) > 1: |
|
|
|
actual_dout_shape_changed += x_shp[1:] |
|
|
|
actual_dout = reshape_op(dout, actual_dout_shape_changed) |
|
|
|
return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) |
|
|
|
return bprop |
|
|
|
|