Browse Source

!2385 [Auto parallel] Keep the 'Sub' in the bprob of EmbeddingLookup on device

Merge pull request !2385 from Xiaoda/6-keep-sub-on-device-in-bprob-of-embeddinglookup
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
f96bf04b5f
1 changed files with 13 additions and 10 deletions
  1. +13
    -10
      mindspore/ops/_grad/grad_array_ops.py

+ 13
- 10
mindspore/ops/_grad/grad_array_ops.py View File

@@ -194,23 +194,26 @@ def get_bprop_tile(self):
@bprop_getters.register(inner.EmbeddingLookup)
def get_bprop_embedding_lookup(self):
"""Generate bprop for EmbeddingLookup"""
host_sub = P.Sub().add_prim_attr('primitive_target', 'CPU')
sub_op = P.Sub()
reshape_op = P.Reshape()
host_reshape = P.Reshape().add_prim_attr('primitive_target', 'CPU')
def bprop_sparse(x, indices, offset, reduce_scatter_flag, split_num, out, dout):
x_shp = shape_op(x)
if reduce_scatter_flag is True:
elu_grad = G.EmbeddingLookupCommGrad()
actual_dout = elu_grad(dout, split_num)
else:
actual_dout = dout
new_indices = host_sub(indices, offset)
new_indices = sub_op(indices, offset)
# Reshape the 'new_indices'
new_indices_shape_changed = (size_op(new_indices),)
new_indices = host_reshape(new_indices, new_indices_shape_changed)
# Reshape the 'actual_dout'
new_indices = reshape_op(new_indices, new_indices_shape_changed)
x_shp_tail = x_shp[1:]
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
actual_dout = host_reshape(actual_dout, actual_dout_shape_changed)
if reduce_scatter_flag is True:
# On host
elu_grad = G.EmbeddingLookupCommGrad()
actual_dout = elu_grad(dout, split_num)
# Reshape the 'actual_dout' on host
actual_dout = host_reshape(actual_dout, actual_dout_shape_changed)
else:
# Reshape the 'actual_dout' on device
actual_dout = reshape_op(dout, actual_dout_shape_changed)
return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset), \
zeros_like(reduce_scatter_flag), zeros_like(split_num)
return bprop_sparse


Loading…
Cancel
Save