|
|
|
@@ -200,14 +200,14 @@ def get_bprop_mirror_operator(self): |
|
|
|
float_one = F.scalar_cast(1.0, F.dtype(grad)) |
|
|
|
num = F.scalar_cast(dev_num, F.dtype(grad)) |
|
|
|
grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad))) |
|
|
|
dx = (indices, grad, dout.dense_shape()) |
|
|
|
dx = IndexedSlices(indices, grad, dout.dense_shape()) |
|
|
|
else: |
|
|
|
if F.issubclass_(F.typeof(dout), mstype.tensor): |
|
|
|
dx = all_reduce(dout) |
|
|
|
else: |
|
|
|
indices = all_gather(dout.indices()) |
|
|
|
grad = all_gather(dout.values()) |
|
|
|
dx = (indices, grad, dout.dense_shape()) |
|
|
|
dx = IndexedSlices(indices, grad, dout.dense_shape()) |
|
|
|
|
|
|
|
return (dx,) |
|
|
|
return bprop |
|
|
|
|