
!10743 enable gradients mean in opt shard

From: @gong_zi_yan
Reviewed-by: @stsuteng,@yao_yf,@kisnwang
Signed-off-by: @stsuteng
Tag: v1.2.0-rc1
Committed-by: mindspore-ci-bot · 5 years ago
Parent commit: b07dd76246
2 changed files with 12 additions and 1 deletion:
  1. mindspore/ccsrc/frontend/parallel/step_parallel.cc  (+8, -1)
  2. mindspore/ops/_grad/grad_comm_ops.py  (+4, -0)
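In short: when the parallel-optimizer (opt shard) pass inserts an AllGather in front of a sharded parameter, this commit copies the context's gradients_mean setting onto that op as a mean_flag attribute, and the AllGather bprop then rescales its gradient slice by 1/rank_size when the flag is set. The flag originates from the user-facing parallel context, roughly as below (a usage sketch, not part of the patch):

from mindspore import context

# gradients_mean=True is what ParallelContext::GetInstance()->gradients_mean()
# returns on the C++ side; enable_parallel_optimizer turns on the opt-shard
# pass that inserts the AllGather this commit touches.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  gradients_mean=True,
                                  enable_parallel_optimizer=True)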

mindspore/ccsrc/frontend/parallel/step_parallel.cc  (+8, -1)

@@ -1390,9 +1390,16 @@ static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodeP
     InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
     allgather = cnode->input(res.second)->cast<CNodePtr>();
   }
-  // add fusion flag
   MS_EXCEPTION_IF_NULL(allgather);
+  // add fusion flag
   AddCommOpFusionType(allgather, parameter);
+  // add gradients mean
+  auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
+  auto attrs = prim->attrs();
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
+  attrs["mean_flag"] = MakeValue<bool>(mean_flag);
+  prim->SetAttrs(attrs);
 }
 
 static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
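What this hunk does: after tagging the freshly inserted AllGather with its fusion type, it copies ParallelContext's gradients_mean into the primitive's attribute dict as mean_flag, so the Python bprop below can read it back. A rough Python mirror of that attribute round-trip (illustrative only; the real write happens in C++, and the stand-in Primitive here is hypothetical):

from mindspore.ops import Primitive

# Hypothetical stand-in for the AllGather the opt-shard pass inserts.
allgather = Primitive("AllGather")

# C++ side, mirrored: attrs["mean_flag"] = MakeValue<bool>(mean_flag).
gradients_mean = True  # would come from ParallelContext::gradients_mean()
allgather.add_prim_attr("mean_flag", gradients_mean)

# Python bprop side: the patch reads the flag back exactly this way.
assert allgather.get_attr_dict()["mean_flag"]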


mindspore/ops/_grad/grad_comm_ops.py  (+4, -0)

@@ -134,6 +134,8 @@ def get_bprop_all_gather(self):
     rank = get_rank(self.group)
     dev_num = get_group_size(self.group)
     split = P.Split(output_num=dev_num)
+    mean_flag = self.get_attr_dict()["mean_flag"]
+    scale = 1/self.rank_size
 
     def bprop(x, out, dout):
         if fusion == 0:
@@ -141,6 +143,8 @@ def get_bprop_all_gather(self):
         else:
             grad = all_reduce(dout)
         dx = split(grad)[rank]
+        if mean_flag:
+            dx = F.tensor_mul(dx, scale)
         return (dx,)
 
     return bprop
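The bprop math the new lines implement: AllGather's gradient is an AllReduce of dout followed by taking this rank's slice, so without scaling the slices are summed across ranks; with mean_flag set, multiplying by scale = 1/rank_size turns that sum into a mean. A NumPy sketch under the assumption of rank_size = 2 (pure illustration, no MindSpore involved):

import numpy as np

rank_size = 2                        # assumed number of ranks
rank = 0                             # this rank's index

# Each rank holds its own gradient w.r.t. the full gathered output.
dout_per_rank = [np.array([1., 1., 1., 1.]),   # dout on rank 0
                 np.array([3., 3., 3., 3.])]   # dout on rank 1

# all_reduce(dout): every rank ends up with the element-wise sum.
grad = sum(dout_per_rank)

# split(grad)[rank]: keep only the slice matching this rank's input x.
dx = np.split(grad, rank_size)[rank]

# The new mean_flag branch: rescale the summed gradient into an average.
mean_flag = True
if mean_flag:
    dx = dx * (1.0 / rank_size)

print(dx)  # [2. 2.] -- the mean of the per-rank slices instead of their sum [4. 4.]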

