|
|
|
@@ -194,23 +194,26 @@ def get_bprop_tile(self): |
|
|
|
@bprop_getters.register(inner.EmbeddingLookup) |
|
|
|
def get_bprop_embedding_lookup(self): |
|
|
|
"""Generate bprop for EmbeddingLookup""" |
|
|
|
host_sub = P.Sub().add_prim_attr('primitive_target', 'CPU') |
|
|
|
sub_op = P.Sub() |
|
|
|
reshape_op = P.Reshape() |
|
|
|
host_reshape = P.Reshape().add_prim_attr('primitive_target', 'CPU') |
|
|
|
def bprop_sparse(x, indices, offset, reduce_scatter_flag, split_num, out, dout): |
|
|
|
x_shp = shape_op(x) |
|
|
|
if reduce_scatter_flag is True: |
|
|
|
elu_grad = G.EmbeddingLookupCommGrad() |
|
|
|
actual_dout = elu_grad(dout, split_num) |
|
|
|
else: |
|
|
|
actual_dout = dout |
|
|
|
new_indices = host_sub(indices, offset) |
|
|
|
new_indices = sub_op(indices, offset) |
|
|
|
# Reshape the 'new_indices' |
|
|
|
new_indices_shape_changed = (size_op(new_indices),) |
|
|
|
new_indices = host_reshape(new_indices, new_indices_shape_changed) |
|
|
|
# Reshape the 'actual_dout' |
|
|
|
new_indices = reshape_op(new_indices, new_indices_shape_changed) |
|
|
|
x_shp_tail = x_shp[1:] |
|
|
|
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail |
|
|
|
actual_dout = host_reshape(actual_dout, actual_dout_shape_changed) |
|
|
|
if reduce_scatter_flag is True: |
|
|
|
# On host |
|
|
|
elu_grad = G.EmbeddingLookupCommGrad() |
|
|
|
actual_dout = elu_grad(dout, split_num) |
|
|
|
# Reshape the 'actual_dout' on host |
|
|
|
actual_dout = host_reshape(actual_dout, actual_dout_shape_changed) |
|
|
|
else: |
|
|
|
# Reshape the 'actual_dout' on device |
|
|
|
actual_dout = reshape_op(dout, actual_dout_shape_changed) |
|
|
|
return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset), \ |
|
|
|
zeros_like(reduce_scatter_flag), zeros_like(split_num) |
|
|
|
return bprop_sparse |
|
|
|
|