|
|
|
@@ -1390,9 +1390,16 @@ static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodeP |
|
|
|
InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER); |
|
|
|
allgather = cnode->input(res.second)->cast<CNodePtr>(); |
|
|
|
} |
|
|
|
// add fusion flag |
|
|
|
MS_EXCEPTION_IF_NULL(allgather); |
|
|
|
// add fusion flag |
|
|
|
AddCommOpFusionType(allgather, parameter); |
|
|
|
// add gradients mean |
|
|
|
auto prim = GetValueNode<PrimitivePtr>(allgather->input(0)); |
|
|
|
auto attrs = prim->attrs(); |
|
|
|
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); |
|
|
|
bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); |
|
|
|
attrs["mean_flag"] = MakeValue<bool>(mean_flag); |
|
|
|
prim->SetAttrs(attrs); |
|
|
|
} |
|
|
|
|
|
|
|
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter, |
|
|
|
|