|
|
@@ -16,6 +16,7 @@ |
|
|
"""Generate bprop for comm ops""" |
|
|
"""Generate bprop for comm ops""" |
|
|
import mindspore.common.dtype as mstype |
|
|
import mindspore.common.dtype as mstype |
|
|
from mindspore.ops import functional as F |
|
|
from mindspore.ops import functional as F |
|
|
|
|
|
from mindspore.communication import get_rank, get_group_size |
|
|
from .. import operations as P |
|
|
from .. import operations as P |
|
|
from ...common.tensor import RowTensor |
|
|
from ...common.tensor import RowTensor |
|
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like |
|
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like |
|
|
@@ -116,15 +117,27 @@ def get_bprop_broad_cast(self): |
|
|
@bprop_getters.register(AllGather) |
|
|
@bprop_getters.register(AllGather) |
|
|
def get_bprop_all_gather(self): |
|
|
def get_bprop_all_gather(self): |
|
|
"""Generate bprop for AllGather""" |
|
|
"""Generate bprop for AllGather""" |
|
|
all_gather_grad = ReduceScatter(ReduceOp.SUM, self.group) |
|
|
|
|
|
fusion = self.get_attr_dict()["fusion"] |
|
|
fusion = self.get_attr_dict()["fusion"] |
|
|
all_gather_grad.add_prim_attr("fusion", fusion) |
|
|
|
|
|
if self.instance_name: |
|
|
|
|
|
instance_name = "grad_" + self.instance_name |
|
|
|
|
|
all_gather_grad.set_prim_instance_name(instance_name) |
|
|
|
|
|
|
|
|
if fusion == 0: |
|
|
|
|
|
reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group) |
|
|
|
|
|
if self.instance_name: |
|
|
|
|
|
instance_name = "grad_" + self.instance_name |
|
|
|
|
|
reduce_scatter.set_prim_instance_name(instance_name) |
|
|
|
|
|
else: |
|
|
|
|
|
all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", 1) |
|
|
|
|
|
if self.instance_name: |
|
|
|
|
|
instance_name = "grad_" + self.instance_name |
|
|
|
|
|
all_reduce.set_prim_instance_name(instance_name) |
|
|
|
|
|
rank = get_rank(self.group) |
|
|
|
|
|
dev_num = get_group_size(self.group) |
|
|
|
|
|
split = P.Split(output_num=dev_num) |
|
|
|
|
|
|
|
|
def bprop(x, out, dout): |
|
|
def bprop(x, out, dout): |
|
|
dx = all_gather_grad(dout) |
|
|
|
|
|
|
|
|
if fusion == 0: |
|
|
|
|
|
dx = reduce_scatter(dout) |
|
|
|
|
|
else: |
|
|
|
|
|
grad = all_reduce(dout) |
|
|
|
|
|
dx = split(grad)[rank] |
|
|
return (dx,) |
|
|
return (dx,) |
|
|
|
|
|
|
|
|
return bprop |
|
|
return bprop |
|
|
|