Browse Source

enable_allgather_fusion

tags/v1.1.0
Ziyan 5 years ago
parent
commit
e29f5c96cb
2 changed files with 20 additions and 7 deletions
  1. +1
    -1
      mindspore/ccsrc/frontend/parallel/step_parallel.cc
  2. +19
    -6
      mindspore/ops/_grad/grad_comm_ops.py

+ 1
- 1
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -1408,7 +1408,7 @@ void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int
auto prim = GetValueNode<PrimitivePtr>(allgather->input(0)); auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
auto attrs = prim->attrs(); auto attrs = prim->attrs();
// enable fusion flag later when it's supported in backend // enable fusion flag later when it's supported in backend
attrs["fusion"] = MakeValue<int64_t>(0);
attrs["fusion"] = MakeValue<int64_t>(1);
prim->SetAttrs(attrs); prim->SetAttrs(attrs);
} }




+ 19
- 6
mindspore/ops/_grad/grad_comm_ops.py View File

@@ -16,6 +16,7 @@
"""Generate bprop for comm ops""" """Generate bprop for comm ops"""
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.communication import get_rank, get_group_size
from .. import operations as P from .. import operations as P
from ...common.tensor import RowTensor from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..composite.multitype_ops.zeros_like_impl import zeros_like
@@ -116,15 +117,27 @@ def get_bprop_broad_cast(self):
@bprop_getters.register(AllGather)
def get_bprop_all_gather(self):
    """Generate bprop for AllGather.

    The mathematical gradient of AllGather is ReduceScatter(SUM). When the
    primitive's "fusion" attr is 0, a ReduceScatter op is used directly.
    Otherwise the gradient is computed as AllReduce(SUM) followed by taking
    this rank's slice of the result — equivalent to ReduceScatter, but the
    AllReduce is tagged with fusion=1 so the backend can fuse it with other
    communication ops.
    """
    fusion = self.get_attr_dict()["fusion"]
    if fusion == 0:
        # Unfused path: gradient of AllGather is ReduceScatter(SUM).
        reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group)
        if self.instance_name:
            instance_name = "grad_" + self.instance_name
            reduce_scatter.set_prim_instance_name(instance_name)
    else:
        # Fused path: AllReduce then split out this rank's chunk.
        all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", 1)
        if self.instance_name:
            instance_name = "grad_" + self.instance_name
            all_reduce.set_prim_instance_name(instance_name)
        rank = get_rank(self.group)
        dev_num = get_group_size(self.group)
        split = P.Split(output_num=dev_num)

    def bprop(x, out, dout):
        if fusion == 0:
            dx = reduce_scatter(dout)
        else:
            grad = all_reduce(dout)
            # AllReduce(SUM) + take own slice == ReduceScatter(SUM).
            dx = split(grad)[rank]
        return (dx,)

    return bprop


Loading…
Cancel
Save